1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use futures::future::Shared;
11use language::language_settings::FormatOnSave;
12pub use mention::*;
13use project::lsp_store::{FormatTrigger, LspFormatTarget};
14use serde::{Deserialize, Serialize};
15use settings::Settings as _;
16pub use terminal::*;
17
18use action_log::ActionLog;
19use agent_client_protocol::{self as acp};
20use anyhow::{Context as _, Result, anyhow};
21use editor::Bias;
22use futures::{FutureExt, channel::oneshot, future::BoxFuture};
23use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
24use itertools::Itertools;
25use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
26use markdown::Markdown;
27use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
28use std::collections::HashMap;
29use std::error::Error;
30use std::fmt::{Formatter, Write};
31use std::ops::Range;
32use std::process::ExitStatus;
33use std::rc::Rc;
34use std::time::{Duration, Instant};
35use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
36use ui::App;
37use util::{ResultExt, get_system_shell};
38use uuid::Uuid;
39
40#[derive(Debug)]
41pub struct UserMessage {
42 pub id: Option<UserMessageId>,
43 pub content: ContentBlock,
44 pub chunks: Vec<acp::ContentBlock>,
45 pub checkpoint: Option<Checkpoint>,
46}
47
48#[derive(Debug)]
49pub struct Checkpoint {
50 git_checkpoint: GitStoreCheckpoint,
51 pub show: bool,
52}
53
54impl UserMessage {
55 fn to_markdown(&self, cx: &App) -> String {
56 let mut markdown = String::new();
57 if self
58 .checkpoint
59 .as_ref()
60 .is_some_and(|checkpoint| checkpoint.show)
61 {
62 writeln!(markdown, "## User (checkpoint)").unwrap();
63 } else {
64 writeln!(markdown, "## User").unwrap();
65 }
66 writeln!(markdown).unwrap();
67 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
68 writeln!(markdown).unwrap();
69 markdown
70 }
71}
72
73#[derive(Debug, PartialEq)]
74pub struct AssistantMessage {
75 pub chunks: Vec<AssistantMessageChunk>,
76}
77
78impl AssistantMessage {
79 pub fn to_markdown(&self, cx: &App) -> String {
80 format!(
81 "## Assistant\n\n{}\n\n",
82 self.chunks
83 .iter()
84 .map(|chunk| chunk.to_markdown(cx))
85 .join("\n\n")
86 )
87 }
88}
89
90#[derive(Debug, PartialEq)]
91pub enum AssistantMessageChunk {
92 Message { block: ContentBlock },
93 Thought { block: ContentBlock },
94}
95
96impl AssistantMessageChunk {
97 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
98 Self::Message {
99 block: ContentBlock::new(chunk.into(), language_registry, cx),
100 }
101 }
102
103 fn to_markdown(&self, cx: &App) -> String {
104 match self {
105 Self::Message { block } => block.to_markdown(cx).to_string(),
106 Self::Thought { block } => {
107 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
108 }
109 }
110 }
111}
112
113#[derive(Debug)]
114pub enum AgentThreadEntry {
115 UserMessage(UserMessage),
116 AssistantMessage(AssistantMessage),
117 ToolCall(ToolCall),
118}
119
120impl AgentThreadEntry {
121 pub fn to_markdown(&self, cx: &App) -> String {
122 match self {
123 Self::UserMessage(message) => message.to_markdown(cx),
124 Self::AssistantMessage(message) => message.to_markdown(cx),
125 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
126 }
127 }
128
129 pub fn user_message(&self) -> Option<&UserMessage> {
130 if let AgentThreadEntry::UserMessage(message) = self {
131 Some(message)
132 } else {
133 None
134 }
135 }
136
137 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
138 if let AgentThreadEntry::ToolCall(call) = self {
139 itertools::Either::Left(call.diffs())
140 } else {
141 itertools::Either::Right(std::iter::empty())
142 }
143 }
144
145 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
146 if let AgentThreadEntry::ToolCall(call) = self {
147 itertools::Either::Left(call.terminals())
148 } else {
149 itertools::Either::Right(std::iter::empty())
150 }
151 }
152
153 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
154 if let AgentThreadEntry::ToolCall(ToolCall {
155 locations,
156 resolved_locations,
157 ..
158 }) = self
159 {
160 Some((
161 locations.get(ix)?.clone(),
162 resolved_locations.get(ix)?.clone()?,
163 ))
164 } else {
165 None
166 }
167 }
168}
169
170#[derive(Debug)]
171pub struct ToolCall {
172 pub id: acp::ToolCallId,
173 pub label: Entity<Markdown>,
174 pub kind: acp::ToolKind,
175 pub content: Vec<ToolCallContent>,
176 pub status: ToolCallStatus,
177 pub locations: Vec<acp::ToolCallLocation>,
178 pub resolved_locations: Vec<Option<AgentLocation>>,
179 pub raw_input: Option<serde_json::Value>,
180 pub raw_output: Option<serde_json::Value>,
181}
182
183impl ToolCall {
184 fn from_acp(
185 tool_call: acp::ToolCall,
186 status: ToolCallStatus,
187 language_registry: Arc<LanguageRegistry>,
188 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
189 cx: &mut App,
190 ) -> Result<Self> {
191 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
192 first_line.to_owned() + "…"
193 } else {
194 tool_call.title
195 };
196 let mut content = Vec::with_capacity(tool_call.content.len());
197 for item in tool_call.content {
198 content.push(ToolCallContent::from_acp(
199 item,
200 language_registry.clone(),
201 terminals,
202 cx,
203 )?);
204 }
205
206 let result = Self {
207 id: tool_call.id,
208 label: cx
209 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
210 kind: tool_call.kind,
211 content,
212 locations: tool_call.locations,
213 resolved_locations: Vec::default(),
214 status,
215 raw_input: tool_call.raw_input,
216 raw_output: tool_call.raw_output,
217 };
218 Ok(result)
219 }
220
221 fn update_fields(
222 &mut self,
223 fields: acp::ToolCallUpdateFields,
224 language_registry: Arc<LanguageRegistry>,
225 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
226 cx: &mut App,
227 ) -> Result<()> {
228 let acp::ToolCallUpdateFields {
229 kind,
230 status,
231 title,
232 content,
233 locations,
234 raw_input,
235 raw_output,
236 } = fields;
237
238 if let Some(kind) = kind {
239 self.kind = kind;
240 }
241
242 if let Some(status) = status {
243 self.status = status.into();
244 }
245
246 if let Some(title) = title {
247 self.label.update(cx, |label, cx| {
248 if let Some((first_line, _)) = title.split_once("\n") {
249 label.replace(first_line.to_owned() + "…", cx)
250 } else {
251 label.replace(title, cx);
252 }
253 });
254 }
255
256 if let Some(content) = content {
257 let new_content_len = content.len();
258 let mut content = content.into_iter();
259
260 // Reuse existing content if we can
261 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
262 old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
263 }
264 for new in content {
265 self.content.push(ToolCallContent::from_acp(
266 new,
267 language_registry.clone(),
268 terminals,
269 cx,
270 )?)
271 }
272 self.content.truncate(new_content_len);
273 }
274
275 if let Some(locations) = locations {
276 self.locations = locations;
277 }
278
279 if let Some(raw_input) = raw_input {
280 self.raw_input = Some(raw_input);
281 }
282
283 if let Some(raw_output) = raw_output {
284 if self.content.is_empty()
285 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
286 {
287 self.content
288 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
289 markdown,
290 }));
291 }
292 self.raw_output = Some(raw_output);
293 }
294 Ok(())
295 }
296
297 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
298 self.content.iter().filter_map(|content| match content {
299 ToolCallContent::Diff(diff) => Some(diff),
300 ToolCallContent::ContentBlock(_) => None,
301 ToolCallContent::Terminal(_) => None,
302 })
303 }
304
305 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
306 self.content.iter().filter_map(|content| match content {
307 ToolCallContent::Terminal(terminal) => Some(terminal),
308 ToolCallContent::ContentBlock(_) => None,
309 ToolCallContent::Diff(_) => None,
310 })
311 }
312
313 fn to_markdown(&self, cx: &App) -> String {
314 let mut markdown = format!(
315 "**Tool Call: {}**\nStatus: {}\n\n",
316 self.label.read(cx).source(),
317 self.status
318 );
319 for content in &self.content {
320 markdown.push_str(content.to_markdown(cx).as_str());
321 markdown.push_str("\n\n");
322 }
323 markdown
324 }
325
326 async fn resolve_location(
327 location: acp::ToolCallLocation,
328 project: WeakEntity<Project>,
329 cx: &mut AsyncApp,
330 ) -> Option<AgentLocation> {
331 let buffer = project
332 .update(cx, |project, cx| {
333 project
334 .project_path_for_absolute_path(&location.path, cx)
335 .map(|path| project.open_buffer(path, cx))
336 })
337 .ok()??;
338 let buffer = buffer.await.log_err()?;
339 let position = buffer
340 .update(cx, |buffer, _| {
341 if let Some(row) = location.line {
342 let snapshot = buffer.snapshot();
343 let column = snapshot.indent_size_for_line(row).len;
344 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
345 snapshot.anchor_before(point)
346 } else {
347 Anchor::MIN
348 }
349 })
350 .ok()?;
351
352 Some(AgentLocation {
353 buffer: buffer.downgrade(),
354 position,
355 })
356 }
357
358 fn resolve_locations(
359 &self,
360 project: Entity<Project>,
361 cx: &mut App,
362 ) -> Task<Vec<Option<AgentLocation>>> {
363 let locations = self.locations.clone();
364 project.update(cx, |_, cx| {
365 cx.spawn(async move |project, cx| {
366 let mut new_locations = Vec::new();
367 for location in locations {
368 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
369 }
370 new_locations
371 })
372 })
373 }
374}
375
376#[derive(Debug)]
377pub enum ToolCallStatus {
378 /// The tool call hasn't started running yet, but we start showing it to
379 /// the user.
380 Pending,
381 /// The tool call is waiting for confirmation from the user.
382 WaitingForConfirmation {
383 options: Vec<acp::PermissionOption>,
384 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
385 },
386 /// The tool call is currently running.
387 InProgress,
388 /// The tool call completed successfully.
389 Completed,
390 /// The tool call failed.
391 Failed,
392 /// The user rejected the tool call.
393 Rejected,
394 /// The user canceled generation so the tool call was canceled.
395 Canceled,
396}
397
398impl From<acp::ToolCallStatus> for ToolCallStatus {
399 fn from(status: acp::ToolCallStatus) -> Self {
400 match status {
401 acp::ToolCallStatus::Pending => Self::Pending,
402 acp::ToolCallStatus::InProgress => Self::InProgress,
403 acp::ToolCallStatus::Completed => Self::Completed,
404 acp::ToolCallStatus::Failed => Self::Failed,
405 }
406 }
407}
408
409impl Display for ToolCallStatus {
410 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
411 write!(
412 f,
413 "{}",
414 match self {
415 ToolCallStatus::Pending => "Pending",
416 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
417 ToolCallStatus::InProgress => "In Progress",
418 ToolCallStatus::Completed => "Completed",
419 ToolCallStatus::Failed => "Failed",
420 ToolCallStatus::Rejected => "Rejected",
421 ToolCallStatus::Canceled => "Canceled",
422 }
423 )
424 }
425}
426
427#[derive(Debug, PartialEq, Clone)]
428pub enum ContentBlock {
429 Empty,
430 Markdown { markdown: Entity<Markdown> },
431 ResourceLink { resource_link: acp::ResourceLink },
432}
433
434impl ContentBlock {
435 pub fn new(
436 block: acp::ContentBlock,
437 language_registry: &Arc<LanguageRegistry>,
438 cx: &mut App,
439 ) -> Self {
440 let mut this = Self::Empty;
441 this.append(block, language_registry, cx);
442 this
443 }
444
445 pub fn new_combined(
446 blocks: impl IntoIterator<Item = acp::ContentBlock>,
447 language_registry: Arc<LanguageRegistry>,
448 cx: &mut App,
449 ) -> Self {
450 let mut this = Self::Empty;
451 for block in blocks {
452 this.append(block, &language_registry, cx);
453 }
454 this
455 }
456
457 pub fn append(
458 &mut self,
459 block: acp::ContentBlock,
460 language_registry: &Arc<LanguageRegistry>,
461 cx: &mut App,
462 ) {
463 if matches!(self, ContentBlock::Empty)
464 && let acp::ContentBlock::ResourceLink(resource_link) = block
465 {
466 *self = ContentBlock::ResourceLink { resource_link };
467 return;
468 }
469
470 let new_content = self.block_string_contents(block);
471
472 match self {
473 ContentBlock::Empty => {
474 *self = Self::create_markdown_block(new_content, language_registry, cx);
475 }
476 ContentBlock::Markdown { markdown } => {
477 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
478 }
479 ContentBlock::ResourceLink { resource_link } => {
480 let existing_content = Self::resource_link_md(&resource_link.uri);
481 let combined = format!("{}\n{}", existing_content, new_content);
482
483 *self = Self::create_markdown_block(combined, language_registry, cx);
484 }
485 }
486 }
487
488 fn create_markdown_block(
489 content: String,
490 language_registry: &Arc<LanguageRegistry>,
491 cx: &mut App,
492 ) -> ContentBlock {
493 ContentBlock::Markdown {
494 markdown: cx
495 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
496 }
497 }
498
499 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
500 match block {
501 acp::ContentBlock::Text(text_content) => text_content.text,
502 acp::ContentBlock::ResourceLink(resource_link) => {
503 Self::resource_link_md(&resource_link.uri)
504 }
505 acp::ContentBlock::Resource(acp::EmbeddedResource {
506 resource:
507 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
508 uri,
509 ..
510 }),
511 ..
512 }) => Self::resource_link_md(&uri),
513 acp::ContentBlock::Image(image) => Self::image_md(&image),
514 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
515 }
516 }
517
518 fn resource_link_md(uri: &str) -> String {
519 if let Some(uri) = MentionUri::parse(uri).log_err() {
520 uri.as_link().to_string()
521 } else {
522 uri.to_string()
523 }
524 }
525
526 fn image_md(_image: &acp::ImageContent) -> String {
527 "`Image`".into()
528 }
529
530 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
531 match self {
532 ContentBlock::Empty => "",
533 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
534 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
535 }
536 }
537
538 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
539 match self {
540 ContentBlock::Empty => None,
541 ContentBlock::Markdown { markdown } => Some(markdown),
542 ContentBlock::ResourceLink { .. } => None,
543 }
544 }
545
546 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
547 match self {
548 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
549 _ => None,
550 }
551 }
552}
553
554#[derive(Debug)]
555pub enum ToolCallContent {
556 ContentBlock(ContentBlock),
557 Diff(Entity<Diff>),
558 Terminal(Entity<Terminal>),
559}
560
561impl ToolCallContent {
562 pub fn from_acp(
563 content: acp::ToolCallContent,
564 language_registry: Arc<LanguageRegistry>,
565 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
566 cx: &mut App,
567 ) -> Result<Self> {
568 match content {
569 acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
570 content,
571 &language_registry,
572 cx,
573 ))),
574 acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
575 Diff::finalized(
576 diff.path,
577 diff.old_text,
578 diff.new_text,
579 language_registry,
580 cx,
581 )
582 }))),
583 acp::ToolCallContent::Terminal { terminal_id } => terminals
584 .get(&terminal_id)
585 .cloned()
586 .map(Self::Terminal)
587 .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
588 }
589 }
590
591 pub fn update_from_acp(
592 &mut self,
593 new: acp::ToolCallContent,
594 language_registry: Arc<LanguageRegistry>,
595 terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
596 cx: &mut App,
597 ) -> Result<()> {
598 let needs_update = match (&self, &new) {
599 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
600 old_diff.read(cx).needs_update(
601 new_diff.old_text.as_deref().unwrap_or(""),
602 &new_diff.new_text,
603 cx,
604 )
605 }
606 _ => true,
607 };
608
609 if needs_update {
610 *self = Self::from_acp(new, language_registry, terminals, cx)?;
611 }
612 Ok(())
613 }
614
615 pub fn to_markdown(&self, cx: &App) -> String {
616 match self {
617 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
618 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
619 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
620 }
621 }
622}
623
624#[derive(Debug, PartialEq)]
625pub enum ToolCallUpdate {
626 UpdateFields(acp::ToolCallUpdate),
627 UpdateDiff(ToolCallUpdateDiff),
628 UpdateTerminal(ToolCallUpdateTerminal),
629}
630
631impl ToolCallUpdate {
632 fn id(&self) -> &acp::ToolCallId {
633 match self {
634 Self::UpdateFields(update) => &update.id,
635 Self::UpdateDiff(diff) => &diff.id,
636 Self::UpdateTerminal(terminal) => &terminal.id,
637 }
638 }
639}
640
641impl From<acp::ToolCallUpdate> for ToolCallUpdate {
642 fn from(update: acp::ToolCallUpdate) -> Self {
643 Self::UpdateFields(update)
644 }
645}
646
647impl From<ToolCallUpdateDiff> for ToolCallUpdate {
648 fn from(diff: ToolCallUpdateDiff) -> Self {
649 Self::UpdateDiff(diff)
650 }
651}
652
653#[derive(Debug, PartialEq)]
654pub struct ToolCallUpdateDiff {
655 pub id: acp::ToolCallId,
656 pub diff: Entity<Diff>,
657}
658
659impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
660 fn from(terminal: ToolCallUpdateTerminal) -> Self {
661 Self::UpdateTerminal(terminal)
662 }
663}
664
665#[derive(Debug, PartialEq)]
666pub struct ToolCallUpdateTerminal {
667 pub id: acp::ToolCallId,
668 pub terminal: Entity<Terminal>,
669}
670
671#[derive(Debug, Default)]
672pub struct Plan {
673 pub entries: Vec<PlanEntry>,
674}
675
676#[derive(Debug)]
677pub struct PlanStats<'a> {
678 pub in_progress_entry: Option<&'a PlanEntry>,
679 pub pending: u32,
680 pub completed: u32,
681}
682
683impl Plan {
684 pub fn is_empty(&self) -> bool {
685 self.entries.is_empty()
686 }
687
688 pub fn stats(&self) -> PlanStats<'_> {
689 let mut stats = PlanStats {
690 in_progress_entry: None,
691 pending: 0,
692 completed: 0,
693 };
694
695 for entry in &self.entries {
696 match &entry.status {
697 acp::PlanEntryStatus::Pending => {
698 stats.pending += 1;
699 }
700 acp::PlanEntryStatus::InProgress => {
701 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
702 }
703 acp::PlanEntryStatus::Completed => {
704 stats.completed += 1;
705 }
706 }
707 }
708
709 stats
710 }
711}
712
713#[derive(Debug)]
714pub struct PlanEntry {
715 pub content: Entity<Markdown>,
716 pub priority: acp::PlanEntryPriority,
717 pub status: acp::PlanEntryStatus,
718}
719
720impl PlanEntry {
721 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
722 Self {
723 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
724 priority: entry.priority,
725 status: entry.status,
726 }
727 }
728}
729
730#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
731pub struct TokenUsage {
732 pub max_tokens: u64,
733 pub used_tokens: u64,
734}
735
736impl TokenUsage {
737 pub fn ratio(&self) -> TokenUsageRatio {
738 #[cfg(debug_assertions)]
739 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
740 .unwrap_or("0.8".to_string())
741 .parse()
742 .unwrap();
743 #[cfg(not(debug_assertions))]
744 let warning_threshold: f32 = 0.8;
745
746 // When the maximum is unknown because there is no selected model,
747 // avoid showing the token limit warning.
748 if self.max_tokens == 0 {
749 TokenUsageRatio::Normal
750 } else if self.used_tokens >= self.max_tokens {
751 TokenUsageRatio::Exceeded
752 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
753 TokenUsageRatio::Warning
754 } else {
755 TokenUsageRatio::Normal
756 }
757 }
758}
759
760#[derive(Debug, Clone, PartialEq, Eq)]
761pub enum TokenUsageRatio {
762 Normal,
763 Warning,
764 Exceeded,
765}
766
/// Details of an automatic retry, forwarded to subscribers as
/// [`AcpThreadEvent::Retry`] by `AcpThread::update_retry_status`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the most recent error (presumably the one that triggered this retry).
    pub last_error: SharedString,
    /// Current attempt number. NOTE(review): 0- vs 1-based is set by the producer — confirm.
    pub attempt: usize,
    /// Maximum number of attempts that will be made.
    pub max_attempts: usize,
    /// NOTE(review): presumably when this retry was scheduled — confirm with the producer.
    pub started_at: Instant,
    /// NOTE(review): presumably the delay before the retry fires — confirm with the producer.
    pub duration: Duration,
}
775
/// State of a single agent session: its transcript, plan, token usage, and
/// the connection used to talk to the agent.
pub struct AcpThread {
    /// Display title; changed via `set_title`, which emits `TitleUpdated`.
    title: SharedString,
    /// Transcript entries in order (user messages, assistant messages, tool calls).
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    /// NOTE(review): presumably snapshots of buffers shared with the agent — usage is outside this view.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight generation task; `status()` reports `Idle` whenever this is `None`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    /// Latest usage set via `update_token_usage`.
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities; kept in sync by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task copying capability updates into `prompt_capabilities` and emitting
    /// `PromptCapabilitiesUpdated` (presumably held only to keep it alive).
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Shell name computed once in the background: `bash` when found on PATH
    /// on non-Windows hosts, otherwise the system shell.
    determine_shell: Shared<Task<String>>,
    /// Known terminals by id; tool-call content referencing a terminal id is
    /// resolved against this map.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
}
792
/// Notifications emitted by [`AcpThread`] to its subscribers.
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was appended to the end of the transcript.
    NewEntry,
    /// The thread title changed.
    TitleUpdated,
    /// The token usage was replaced (possibly with `None`).
    TokenUsageUpdated,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
    /// NOTE(review): emitted outside this view — presumably the entries in this range were removed.
    EntriesRemoved(Range<usize>),
    /// NOTE(review): emitted outside this view — presumably a tool call now awaits user authorization.
    ToolAuthorizationRequired,
    /// Retry details forwarded from `update_retry_status`.
    Retry(RetryStatus),
    /// NOTE(review): emitted outside this view — presumably generation stopped.
    Stopped,
    /// NOTE(review): emitted outside this view — presumably generation failed.
    Error,
    /// NOTE(review): emitted outside this view — presumably the agent failed to load.
    LoadError(LoadError),
    /// `prompt_capabilities()` now returns a newly received value.
    PromptCapabilitiesUpdated,
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
809
/// Coarse activity state reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is running and `waiting_for_tool_confirmation()` is true.
    WaitingForToolConfirmation,
    /// A send task is running and nothing is awaiting confirmation.
    Generating,
}
816
817#[derive(Debug, Clone)]
818pub enum LoadError {
819 Unsupported {
820 command: SharedString,
821 current_version: SharedString,
822 minimum_version: SharedString,
823 },
824 FailedToInstall(SharedString),
825 Exited {
826 status: ExitStatus,
827 },
828 Other(SharedString),
829}
830
831impl Display for LoadError {
832 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
833 match self {
834 LoadError::Unsupported {
835 command: path,
836 current_version,
837 minimum_version,
838 } => {
839 write!(
840 f,
841 "version {current_version} from {path} is not supported (need at least {minimum_version})"
842 )
843 }
844 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
845 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
846 LoadError::Other(msg) => write!(f, "{msg}"),
847 }
848 }
849}
850
851impl Error for LoadError {}
852
853impl AcpThread {
854 pub fn new(
855 title: impl Into<SharedString>,
856 connection: Rc<dyn AgentConnection>,
857 project: Entity<Project>,
858 action_log: Entity<ActionLog>,
859 session_id: acp::SessionId,
860 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
861 cx: &mut Context<Self>,
862 ) -> Self {
863 let prompt_capabilities = *prompt_capabilities_rx.borrow();
864 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
865 loop {
866 let caps = prompt_capabilities_rx.recv().await?;
867 this.update(cx, |this, cx| {
868 this.prompt_capabilities = caps;
869 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
870 })?;
871 }
872 });
873
874 let determine_shell = cx
875 .background_spawn(async move {
876 if cfg!(windows) {
877 return get_system_shell();
878 }
879
880 if which::which("bash").is_ok() {
881 "bash".into()
882 } else {
883 get_system_shell()
884 }
885 })
886 .shared();
887
888 Self {
889 action_log,
890 shared_buffers: Default::default(),
891 entries: Default::default(),
892 plan: Default::default(),
893 title: title.into(),
894 project,
895 send_task: None,
896 connection,
897 session_id,
898 token_usage: None,
899 prompt_capabilities,
900 _observe_prompt_capabilities: task,
901 terminals: HashMap::default(),
902 determine_shell,
903 }
904 }
905
906 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
907 self.prompt_capabilities
908 }
909
910 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
911 &self.connection
912 }
913
914 pub fn action_log(&self) -> &Entity<ActionLog> {
915 &self.action_log
916 }
917
918 pub fn project(&self) -> &Entity<Project> {
919 &self.project
920 }
921
922 pub fn title(&self) -> SharedString {
923 self.title.clone()
924 }
925
926 pub fn entries(&self) -> &[AgentThreadEntry] {
927 &self.entries
928 }
929
930 pub fn session_id(&self) -> &acp::SessionId {
931 &self.session_id
932 }
933
934 pub fn status(&self) -> ThreadStatus {
935 if self.send_task.is_some() {
936 if self.waiting_for_tool_confirmation() {
937 ThreadStatus::WaitingForToolConfirmation
938 } else {
939 ThreadStatus::Generating
940 }
941 } else {
942 ThreadStatus::Idle
943 }
944 }
945
946 pub fn token_usage(&self) -> Option<&TokenUsage> {
947 self.token_usage.as_ref()
948 }
949
950 pub fn has_pending_edit_tool_calls(&self) -> bool {
951 for entry in self.entries.iter().rev() {
952 match entry {
953 AgentThreadEntry::UserMessage(_) => return false,
954 AgentThreadEntry::ToolCall(
955 call @ ToolCall {
956 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
957 ..
958 },
959 ) if call.diffs().next().is_some() => {
960 return true;
961 }
962 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
963 }
964 }
965
966 false
967 }
968
969 pub fn used_tools_since_last_user_message(&self) -> bool {
970 for entry in self.entries.iter().rev() {
971 match entry {
972 AgentThreadEntry::UserMessage(..) => return false,
973 AgentThreadEntry::AssistantMessage(..) => continue,
974 AgentThreadEntry::ToolCall(..) => return true,
975 }
976 }
977
978 false
979 }
980
981 pub fn handle_session_update(
982 &mut self,
983 update: acp::SessionUpdate,
984 cx: &mut Context<Self>,
985 ) -> Result<(), acp::Error> {
986 match update {
987 acp::SessionUpdate::UserMessageChunk { content } => {
988 self.push_user_content_block(None, content, cx);
989 }
990 acp::SessionUpdate::AgentMessageChunk { content } => {
991 self.push_assistant_content_block(content, false, cx);
992 }
993 acp::SessionUpdate::AgentThoughtChunk { content } => {
994 self.push_assistant_content_block(content, true, cx);
995 }
996 acp::SessionUpdate::ToolCall(tool_call) => {
997 self.upsert_tool_call(tool_call, cx)?;
998 }
999 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1000 self.update_tool_call(tool_call_update, cx)?;
1001 }
1002 acp::SessionUpdate::Plan(plan) => {
1003 self.update_plan(plan, cx);
1004 }
1005 }
1006 Ok(())
1007 }
1008
1009 pub fn push_user_content_block(
1010 &mut self,
1011 message_id: Option<UserMessageId>,
1012 chunk: acp::ContentBlock,
1013 cx: &mut Context<Self>,
1014 ) {
1015 let language_registry = self.project.read(cx).languages().clone();
1016 let entries_len = self.entries.len();
1017
1018 if let Some(last_entry) = self.entries.last_mut()
1019 && let AgentThreadEntry::UserMessage(UserMessage {
1020 id,
1021 content,
1022 chunks,
1023 ..
1024 }) = last_entry
1025 {
1026 *id = message_id.or(id.take());
1027 content.append(chunk.clone(), &language_registry, cx);
1028 chunks.push(chunk);
1029 let idx = entries_len - 1;
1030 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1031 } else {
1032 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1033 self.push_entry(
1034 AgentThreadEntry::UserMessage(UserMessage {
1035 id: message_id,
1036 content,
1037 chunks: vec![chunk],
1038 checkpoint: None,
1039 }),
1040 cx,
1041 );
1042 }
1043 }
1044
1045 pub fn push_assistant_content_block(
1046 &mut self,
1047 chunk: acp::ContentBlock,
1048 is_thought: bool,
1049 cx: &mut Context<Self>,
1050 ) {
1051 let language_registry = self.project.read(cx).languages().clone();
1052 let entries_len = self.entries.len();
1053 if let Some(last_entry) = self.entries.last_mut()
1054 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1055 {
1056 let idx = entries_len - 1;
1057 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1058 match (chunks.last_mut(), is_thought) {
1059 (Some(AssistantMessageChunk::Message { block }), false)
1060 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1061 block.append(chunk, &language_registry, cx)
1062 }
1063 _ => {
1064 let block = ContentBlock::new(chunk, &language_registry, cx);
1065 if is_thought {
1066 chunks.push(AssistantMessageChunk::Thought { block })
1067 } else {
1068 chunks.push(AssistantMessageChunk::Message { block })
1069 }
1070 }
1071 }
1072 } else {
1073 let block = ContentBlock::new(chunk, &language_registry, cx);
1074 let chunk = if is_thought {
1075 AssistantMessageChunk::Thought { block }
1076 } else {
1077 AssistantMessageChunk::Message { block }
1078 };
1079
1080 self.push_entry(
1081 AgentThreadEntry::AssistantMessage(AssistantMessage {
1082 chunks: vec![chunk],
1083 }),
1084 cx,
1085 );
1086 }
1087 }
1088
1089 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1090 self.entries.push(entry);
1091 cx.emit(AcpThreadEvent::NewEntry);
1092 }
1093
1094 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1095 self.connection.set_title(&self.session_id, cx).is_some()
1096 }
1097
1098 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1099 if title != self.title {
1100 self.title = title.clone();
1101 cx.emit(AcpThreadEvent::TitleUpdated);
1102 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1103 return set_title.run(title, cx);
1104 }
1105 }
1106 Task::ready(Ok(()))
1107 }
1108
1109 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1110 self.token_usage = usage;
1111 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1112 }
1113
1114 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1115 cx.emit(AcpThreadEvent::Retry(status));
1116 }
1117
1118 pub fn update_tool_call(
1119 &mut self,
1120 update: impl Into<ToolCallUpdate>,
1121 cx: &mut Context<Self>,
1122 ) -> Result<()> {
1123 let update = update.into();
1124 let languages = self.project.read(cx).languages().clone();
1125
1126 let ix = self
1127 .index_for_tool_call(update.id())
1128 .context("Tool call not found")?;
1129 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1130 unreachable!()
1131 };
1132
1133 match update {
1134 ToolCallUpdate::UpdateFields(update) => {
1135 let location_updated = update.fields.locations.is_some();
1136 call.update_fields(update.fields, languages, &self.terminals, cx)?;
1137 if location_updated {
1138 self.resolve_locations(update.id, cx);
1139 }
1140 }
1141 ToolCallUpdate::UpdateDiff(update) => {
1142 call.content.clear();
1143 call.content.push(ToolCallContent::Diff(update.diff));
1144 }
1145 ToolCallUpdate::UpdateTerminal(update) => {
1146 call.content.clear();
1147 call.content
1148 .push(ToolCallContent::Terminal(update.terminal));
1149 }
1150 }
1151
1152 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1153
1154 Ok(())
1155 }
1156
1157 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1158 pub fn upsert_tool_call(
1159 &mut self,
1160 tool_call: acp::ToolCall,
1161 cx: &mut Context<Self>,
1162 ) -> Result<(), acp::Error> {
1163 let status = tool_call.status.into();
1164 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1165 }
1166
1167 /// Fails if id does not match an existing entry.
1168 pub fn upsert_tool_call_inner(
1169 &mut self,
1170 update: acp::ToolCallUpdate,
1171 status: ToolCallStatus,
1172 cx: &mut Context<Self>,
1173 ) -> Result<(), acp::Error> {
1174 let language_registry = self.project.read(cx).languages().clone();
1175 let id = update.id.clone();
1176
1177 if let Some(ix) = self.index_for_tool_call(&id) {
1178 let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1179 unreachable!()
1180 };
1181
1182 call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
1183 call.status = status;
1184
1185 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1186 } else {
1187 let call = ToolCall::from_acp(
1188 update.try_into()?,
1189 status,
1190 language_registry,
1191 &self.terminals,
1192 cx,
1193 )?;
1194 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1195 };
1196
1197 self.resolve_locations(id, cx);
1198 Ok(())
1199 }
1200
1201 fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1202 self.entries
1203 .iter()
1204 .enumerate()
1205 .rev()
1206 .find_map(|(index, entry)| {
1207 if let AgentThreadEntry::ToolCall(tool_call) = entry
1208 && &tool_call.id == id
1209 {
1210 Some(index)
1211 } else {
1212 None
1213 }
1214 })
1215 }
1216
1217 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1218 // The tool call we are looking for is typically the last one, or very close to the end.
1219 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1220 self.entries
1221 .iter_mut()
1222 .enumerate()
1223 .rev()
1224 .find_map(|(index, tool_call)| {
1225 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1226 && &tool_call.id == id
1227 {
1228 Some((index, tool_call))
1229 } else {
1230 None
1231 }
1232 })
1233 }
1234
1235 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1236 self.entries
1237 .iter()
1238 .enumerate()
1239 .rev()
1240 .find_map(|(index, tool_call)| {
1241 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1242 && &tool_call.id == id
1243 {
1244 Some((index, tool_call))
1245 } else {
1246 None
1247 }
1248 })
1249 }
1250
    /// Resolves the locations reported by the tool call `id` in the background, then
    /// stores the result on that tool call.
    ///
    /// The last resolved location (if any) may also become the project's agent location,
    /// and `AcpThreadEvent::EntryUpdated` is emitted when the resolved locations changed.
    /// Does nothing if no tool call with `id` exists.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;
            this.update(cx, |this, cx| {
                let project = this.project.clone();
                // Look the tool call up again: the entries may have changed (or the tool
                // call may have been removed) while the locations were being resolved.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };
                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        // NOTE(review): the agent location is only moved when one is already
                        // set; if `agent_location()` is `None`, nothing is updated here.
                        // Confirm this is intended.
                        if let Some(agent_location) = project.agent_location() {
                            let should_ignore = agent_location.buffer == location.buffer
                                && location
                                    .buffer
                                    .update(cx, |buffer, _| {
                                        let snapshot = buffer.snapshot();
                                        let old_position =
                                            agent_location.position.to_point(&snapshot);
                                        let new_position = location.position.to_point(&snapshot);
                                        // Ignore a new location on the same row that lies before
                                        // the current one, so that when we get updates from the
                                        // edit tool the position doesn't reset to the start of
                                        // the line.
                                        old_position.row == new_position.row
                                            && old_position.column > new_position.column
                                    })
                                    .ok()
                                    .unwrap_or_default();
                            if !should_ignore {
                                project.set_agent_location(Some(location.clone()), cx);
                            }
                        }
                    });
                }
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1296
1297 pub fn request_tool_call_authorization(
1298 &mut self,
1299 tool_call: acp::ToolCallUpdate,
1300 options: Vec<acp::PermissionOption>,
1301 cx: &mut Context<Self>,
1302 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1303 let (tx, rx) = oneshot::channel();
1304
1305 if AgentSettings::get_global(cx).always_allow_tool_actions {
1306 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1307 // some tools would (incorrectly) continue to auto-accept.
1308 if let Some(allow_once_option) = options.iter().find_map(|option| {
1309 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1310 Some(option.id.clone())
1311 } else {
1312 None
1313 }
1314 }) {
1315 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1316 return Ok(async {
1317 acp::RequestPermissionOutcome::Selected {
1318 option_id: allow_once_option,
1319 }
1320 }
1321 .boxed());
1322 }
1323 }
1324
1325 let status = ToolCallStatus::WaitingForConfirmation {
1326 options,
1327 respond_tx: tx,
1328 };
1329
1330 self.upsert_tool_call_inner(tool_call, status, cx)?;
1331 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1332
1333 let fut = async {
1334 match rx.await {
1335 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1336 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1337 }
1338 }
1339 .boxed();
1340
1341 Ok(fut)
1342 }
1343
1344 pub fn authorize_tool_call(
1345 &mut self,
1346 id: acp::ToolCallId,
1347 option_id: acp::PermissionOptionId,
1348 option_kind: acp::PermissionOptionKind,
1349 cx: &mut Context<Self>,
1350 ) {
1351 let Some((ix, call)) = self.tool_call_mut(&id) else {
1352 return;
1353 };
1354
1355 let new_status = match option_kind {
1356 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1357 ToolCallStatus::Rejected
1358 }
1359 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1360 ToolCallStatus::InProgress
1361 }
1362 };
1363
1364 let curr_status = mem::replace(&mut call.status, new_status);
1365
1366 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1367 respond_tx.send(option_id).log_err();
1368 } else if cfg!(debug_assertions) {
1369 panic!("tried to authorize an already authorized tool call");
1370 }
1371
1372 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1373 }
1374
1375 /// Returns true if the last turn is awaiting tool authorization
1376 pub fn waiting_for_tool_confirmation(&self) -> bool {
1377 for entry in self.entries.iter().rev() {
1378 match &entry {
1379 AgentThreadEntry::ToolCall(call) => match call.status {
1380 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1381 ToolCallStatus::Pending
1382 | ToolCallStatus::InProgress
1383 | ToolCallStatus::Completed
1384 | ToolCallStatus::Failed
1385 | ToolCallStatus::Rejected
1386 | ToolCallStatus::Canceled => continue,
1387 },
1388 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1389 // Reached the beginning of the turn
1390 return false;
1391 }
1392 }
1393 }
1394 false
1395 }
1396
1397 pub fn plan(&self) -> &Plan {
1398 &self.plan
1399 }
1400
1401 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1402 let new_entries_len = request.entries.len();
1403 let mut new_entries = request.entries.into_iter();
1404
1405 // Reuse existing markdown to prevent flickering
1406 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1407 let PlanEntry {
1408 content,
1409 priority,
1410 status,
1411 } = old;
1412 content.update(cx, |old, cx| {
1413 old.replace(new.content, cx);
1414 });
1415 *priority = new.priority;
1416 *status = new.status;
1417 }
1418 for new in new_entries {
1419 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1420 }
1421 self.plan.entries.truncate(new_entries_len);
1422
1423 cx.notify();
1424 }
1425
1426 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1427 self.plan
1428 .entries
1429 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1430 cx.notify();
1431 }
1432
1433 #[cfg(any(test, feature = "test-support"))]
1434 pub fn send_raw(
1435 &mut self,
1436 message: &str,
1437 cx: &mut Context<Self>,
1438 ) -> BoxFuture<'static, Result<()>> {
1439 self.send(
1440 vec![acp::ContentBlock::Text(acp::TextContent {
1441 text: message.to_string(),
1442 annotations: None,
1443 })],
1444 cx,
1445 )
1446 }
1447
1448 pub fn send(
1449 &mut self,
1450 message: Vec<acp::ContentBlock>,
1451 cx: &mut Context<Self>,
1452 ) -> BoxFuture<'static, Result<()>> {
1453 let block = ContentBlock::new_combined(
1454 message.clone(),
1455 self.project.read(cx).languages().clone(),
1456 cx,
1457 );
1458 let request = acp::PromptRequest {
1459 prompt: message.clone(),
1460 session_id: self.session_id.clone(),
1461 };
1462 let git_store = self.project.read(cx).git_store().clone();
1463
1464 let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1465 Some(UserMessageId::new())
1466 } else {
1467 None
1468 };
1469
1470 self.run_turn(cx, async move |this, cx| {
1471 this.update(cx, |this, cx| {
1472 this.push_entry(
1473 AgentThreadEntry::UserMessage(UserMessage {
1474 id: message_id.clone(),
1475 content: block,
1476 chunks: message,
1477 checkpoint: None,
1478 }),
1479 cx,
1480 );
1481 })
1482 .ok();
1483
1484 let old_checkpoint = git_store
1485 .update(cx, |git, cx| git.checkpoint(cx))?
1486 .await
1487 .context("failed to get old checkpoint")
1488 .log_err();
1489 this.update(cx, |this, cx| {
1490 if let Some((_ix, message)) = this.last_user_message() {
1491 message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1492 git_checkpoint,
1493 show: false,
1494 });
1495 }
1496 this.connection.prompt(message_id, request, cx)
1497 })?
1498 .await
1499 })
1500 }
1501
1502 pub fn can_resume(&self, cx: &App) -> bool {
1503 self.connection.resume(&self.session_id, cx).is_some()
1504 }
1505
1506 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1507 self.run_turn(cx, async move |this, cx| {
1508 this.update(cx, |this, cx| {
1509 this.connection
1510 .resume(&this.session_id, cx)
1511 .map(|resume| resume.run(cx))
1512 })?
1513 .context("resuming a session is not supported")?
1514 .await
1515 })
1516 }
1517
    /// Runs one agent turn driven by `f` and returns a future that resolves once the turn
    /// has finished and the thread has been updated accordingly.
    ///
    /// Before `f` starts, completed plan entries are cleared and any in-flight turn is
    /// canceled and awaited. `f` runs inside `send_task`. When it finishes:
    /// - the last user message's checkpoint is refreshed;
    /// - the project's agent location is cleared;
    /// - `AcpThreadEvent::Error` is emitted if `f` failed, otherwise
    ///   `AcpThreadEvent::Stopped`.
    ///
    /// A `Refusal` stop reason also removes the last user message and every entry after it.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        // The turn's result is forwarded through this channel. The returned future can then
        // observe it while `send_task` stays owned by the thread, where `cancel` can take it.
        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Let the previous turn (if any) finish winding down before starting this one.
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(Canceled)` here means `send_task` was dropped before reporting a result.
            let response = rx.await;

            // NOTE(review): if refreshing the checkpoint fails, `?` returns early, and the
            // agent location reset and the `Error`/`Stopped` events below are skipped.
            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    // Also reached when `rx` was canceled (`Err(Canceled)`).
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1584
1585 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1586 let Some(send_task) = self.send_task.take() else {
1587 return Task::ready(());
1588 };
1589
1590 for entry in self.entries.iter_mut() {
1591 if let AgentThreadEntry::ToolCall(call) = entry {
1592 let cancel = matches!(
1593 call.status,
1594 ToolCallStatus::Pending
1595 | ToolCallStatus::WaitingForConfirmation { .. }
1596 | ToolCallStatus::InProgress
1597 );
1598
1599 if cancel {
1600 call.status = ToolCallStatus::Canceled;
1601 }
1602 }
1603 }
1604
1605 self.connection.cancel(&self.session_id, cx);
1606
1607 // Wait for the send task to complete
1608 cx.foreground_executor().spawn(send_task)
1609 }
1610
1611 /// Rewinds this thread to before the entry at `index`, removing it and all
1612 /// subsequent entries while reverting any changes made from that point.
1613 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1614 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1615 return Task::ready(Err(anyhow!("not supported")));
1616 };
1617 let Some(message) = self.user_message(&id) else {
1618 return Task::ready(Err(anyhow!("message not found")));
1619 };
1620
1621 let checkpoint = message
1622 .checkpoint
1623 .as_ref()
1624 .map(|c| c.git_checkpoint.clone());
1625
1626 let git_store = self.project.read(cx).git_store().clone();
1627 cx.spawn(async move |this, cx| {
1628 if let Some(checkpoint) = checkpoint {
1629 git_store
1630 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1631 .await?;
1632 }
1633
1634 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1635 this.update(cx, |this, cx| {
1636 if let Some((ix, _)) = this.user_message_mut(&id) {
1637 let range = ix..this.entries.len();
1638 this.entries.truncate(ix);
1639 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1640 }
1641 })
1642 })
1643 }
1644
1645 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1646 let git_store = self.project.read(cx).git_store().clone();
1647
1648 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1649 if let Some(checkpoint) = message.checkpoint.as_ref() {
1650 checkpoint.git_checkpoint.clone()
1651 } else {
1652 return Task::ready(Ok(()));
1653 }
1654 } else {
1655 return Task::ready(Ok(()));
1656 };
1657
1658 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1659 cx.spawn(async move |this, cx| {
1660 let new_checkpoint = new_checkpoint
1661 .await
1662 .context("failed to get new checkpoint")
1663 .log_err();
1664 if let Some(new_checkpoint) = new_checkpoint {
1665 let equal = git_store
1666 .update(cx, |git, cx| {
1667 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1668 })?
1669 .await
1670 .unwrap_or(true);
1671 this.update(cx, |this, cx| {
1672 let (ix, message) = this.last_user_message().context("no user message")?;
1673 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1674 checkpoint.show = !equal;
1675 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1676 anyhow::Ok(())
1677 })??;
1678 }
1679
1680 Ok(())
1681 })
1682 }
1683
1684 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1685 self.entries
1686 .iter_mut()
1687 .enumerate()
1688 .rev()
1689 .find_map(|(ix, entry)| {
1690 if let AgentThreadEntry::UserMessage(message) = entry {
1691 Some((ix, message))
1692 } else {
1693 None
1694 }
1695 })
1696 }
1697
1698 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1699 self.entries.iter().find_map(|entry| {
1700 if let AgentThreadEntry::UserMessage(message) = entry {
1701 if message.id.as_ref() == Some(id) {
1702 Some(message)
1703 } else {
1704 None
1705 }
1706 } else {
1707 None
1708 }
1709 })
1710 }
1711
1712 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1713 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1714 if let AgentThreadEntry::UserMessage(message) = entry {
1715 if message.id.as_ref() == Some(id) {
1716 Some((ix, message))
1717 } else {
1718 None
1719 }
1720 } else {
1721 None
1722 }
1723 })
1724 }
1725
1726 pub fn read_text_file(
1727 &self,
1728 path: PathBuf,
1729 line: Option<u32>,
1730 limit: Option<u32>,
1731 reuse_shared_snapshot: bool,
1732 cx: &mut Context<Self>,
1733 ) -> Task<Result<String>> {
1734 let project = self.project.clone();
1735 let action_log = self.action_log.clone();
1736 cx.spawn(async move |this, cx| {
1737 let load = project.update(cx, |project, cx| {
1738 let path = project
1739 .project_path_for_absolute_path(&path, cx)
1740 .context("invalid path")?;
1741 anyhow::Ok(project.open_buffer(path, cx))
1742 });
1743 let buffer = load??.await?;
1744
1745 let snapshot = if reuse_shared_snapshot {
1746 this.read_with(cx, |this, _| {
1747 this.shared_buffers.get(&buffer.clone()).cloned()
1748 })
1749 .log_err()
1750 .flatten()
1751 } else {
1752 None
1753 };
1754
1755 let snapshot = if let Some(snapshot) = snapshot {
1756 snapshot
1757 } else {
1758 action_log.update(cx, |action_log, cx| {
1759 action_log.buffer_read(buffer.clone(), cx);
1760 })?;
1761 project.update(cx, |project, cx| {
1762 let position = buffer
1763 .read(cx)
1764 .snapshot()
1765 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1766 project.set_agent_location(
1767 Some(AgentLocation {
1768 buffer: buffer.downgrade(),
1769 position,
1770 }),
1771 cx,
1772 );
1773 })?;
1774
1775 buffer.update(cx, |buffer, _| buffer.snapshot())?
1776 };
1777
1778 this.update(cx, |this, _| {
1779 let text = snapshot.text();
1780 this.shared_buffers.insert(buffer.clone(), snapshot);
1781 if line.is_none() && limit.is_none() {
1782 return Ok(text);
1783 }
1784 let limit = limit.unwrap_or(u32::MAX) as usize;
1785 let Some(line) = line else {
1786 return Ok(text.lines().take(limit).collect::<String>());
1787 };
1788
1789 let count = text.lines().count();
1790 if count < line as usize {
1791 anyhow::bail!("There are only {} lines", count);
1792 }
1793 Ok(text
1794 .lines()
1795 .skip(line as usize + 1)
1796 .take(limit)
1797 .collect::<String>())
1798 })?
1799 })
1800 }
1801
    /// Replaces the contents of the file at `path` with `content` on behalf of the agent,
    /// then saves the buffer.
    ///
    /// The new text is diffed against the snapshot last shared with the agent, or the
    /// buffer's current snapshot if none was shared. The diff is applied as anchored edits,
    /// so edits made in the meantime elsewhere in the buffer are kept. Afterwards:
    /// - the agent location moves to the end of the last edit;
    /// - the edit is recorded in the action log;
    /// - the buffer is formatted before saving if its `format_on_save` setting is not `Off`.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against the text the agent last saw. Anchors created from that snapshot
            // still resolve if the buffer was edited since.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff off the main thread.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Point the agent location at the end of the last edit (or the buffer start
            // when there are no edits).
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            // Apply the edits between `buffer_read` and `buffer_edited` so the action log
            // attributes them to the agent; also decide whether to format before saving.
            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting failures are logged but do not prevent saving.
                format_task.await.log_err();

                // Record the formatter's changes as agent edits too.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1898
1899 pub fn create_terminal(
1900 &self,
1901 mut command: String,
1902 args: Vec<String>,
1903 extra_env: Vec<acp::EnvVariable>,
1904 cwd: Option<PathBuf>,
1905 output_byte_limit: Option<u64>,
1906 cx: &mut Context<Self>,
1907 ) -> Task<Result<Entity<Terminal>>> {
1908 for arg in args {
1909 command.push(' ');
1910 command.push_str(&arg);
1911 }
1912
1913 let shell_command = if cfg!(windows) {
1914 format!("$null | & {{{}}}", command.replace("\"", "'"))
1915 } else if let Some(cwd) = cwd.as_ref().and_then(|cwd| cwd.as_os_str().to_str()) {
1916 // Make sure once we're *inside* the shell, we cd into `cwd`
1917 format!("(cd {cwd}; {}) </dev/null", command)
1918 } else {
1919 format!("({}) </dev/null", command)
1920 };
1921 let args = vec!["-c".into(), shell_command];
1922
1923 let env = match &cwd {
1924 Some(dir) => self.project.update(cx, |project, cx| {
1925 project.directory_environment(dir.as_path().into(), cx)
1926 }),
1927 None => Task::ready(None).shared(),
1928 };
1929
1930 let env = cx.spawn(async move |_, _| {
1931 let mut env = env.await.unwrap_or_default();
1932 if cfg!(unix) {
1933 env.insert("PAGER".into(), "cat".into());
1934 }
1935 for var in extra_env {
1936 env.insert(var.name, var.value);
1937 }
1938 env
1939 });
1940
1941 let project = self.project.clone();
1942 let language_registry = project.read(cx).languages().clone();
1943 let determine_shell = self.determine_shell.clone();
1944
1945 let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
1946 let terminal_task = cx.spawn({
1947 let terminal_id = terminal_id.clone();
1948 async move |_this, cx| {
1949 let program = determine_shell.await;
1950 let env = env.await;
1951 let terminal = project
1952 .update(cx, |project, cx| {
1953 project.create_terminal_task(
1954 task::SpawnInTerminal {
1955 command: Some(program),
1956 args,
1957 cwd: cwd.clone(),
1958 env,
1959 ..Default::default()
1960 },
1961 cx,
1962 )
1963 })?
1964 .await?;
1965
1966 cx.new(|cx| {
1967 Terminal::new(
1968 terminal_id,
1969 command,
1970 cwd,
1971 output_byte_limit.map(|l| l as usize),
1972 terminal,
1973 language_registry,
1974 cx,
1975 )
1976 })
1977 }
1978 });
1979
1980 cx.spawn(async move |this, cx| {
1981 let terminal = terminal_task.await?;
1982 this.update(cx, |this, _cx| {
1983 this.terminals.insert(terminal_id, terminal.clone());
1984 terminal
1985 })
1986 })
1987 }
1988
1989 pub fn kill_terminal(
1990 &mut self,
1991 terminal_id: acp::TerminalId,
1992 cx: &mut Context<Self>,
1993 ) -> Result<()> {
1994 self.terminals
1995 .get(&terminal_id)
1996 .context("Terminal not found")?
1997 .update(cx, |terminal, cx| {
1998 terminal.kill(cx);
1999 });
2000
2001 Ok(())
2002 }
2003
2004 pub fn release_terminal(
2005 &mut self,
2006 terminal_id: acp::TerminalId,
2007 cx: &mut Context<Self>,
2008 ) -> Result<()> {
2009 self.terminals
2010 .remove(&terminal_id)
2011 .context("Terminal not found")?
2012 .update(cx, |terminal, cx| {
2013 terminal.kill(cx);
2014 });
2015
2016 Ok(())
2017 }
2018
2019 pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2020 self.terminals
2021 .get(&terminal_id)
2022 .context("Terminal not found")
2023 .cloned()
2024 }
2025
2026 pub fn to_markdown(&self, cx: &App) -> String {
2027 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2028 }
2029
2030 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2031 cx.emit(AcpThreadEvent::LoadError(error));
2032 }
2033}
2034
2035fn markdown_for_raw_output(
2036 raw_output: &serde_json::Value,
2037 language_registry: &Arc<LanguageRegistry>,
2038 cx: &mut App,
2039) -> Option<Entity<Markdown>> {
2040 match raw_output {
2041 serde_json::Value::Null => None,
2042 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2043 Markdown::new(
2044 value.to_string().into(),
2045 Some(language_registry.clone()),
2046 None,
2047 cx,
2048 )
2049 })),
2050 serde_json::Value::Number(value) => Some(cx.new(|cx| {
2051 Markdown::new(
2052 value.to_string().into(),
2053 Some(language_registry.clone()),
2054 None,
2055 cx,
2056 )
2057 })),
2058 serde_json::Value::String(value) => Some(cx.new(|cx| {
2059 Markdown::new(
2060 value.clone().into(),
2061 Some(language_registry.clone()),
2062 None,
2063 cx,
2064 )
2065 })),
2066 value => Some(cx.new(|cx| {
2067 Markdown::new(
2068 format!("```json\n{}\n```", value).into(),
2069 Some(language_registry.clone()),
2070 None,
2071 cx,
2072 )
2073 })),
2074 }
2075}
2076
2077#[cfg(test)]
2078mod tests {
2079 use super::*;
2080 use anyhow::anyhow;
2081 use futures::{channel::mpsc, future::LocalBoxFuture, select};
2082 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2083 use indoc::indoc;
2084 use project::{FakeFs, Fs};
2085 use rand::Rng as _;
2086 use serde_json::json;
2087 use settings::SettingsStore;
2088 use smol::stream::StreamExt as _;
2089 use std::{
2090 any::Any,
2091 cell::RefCell,
2092 path::Path,
2093 rc::Rc,
2094 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2095 time::Duration,
2096 };
2097 use util::path;
2098
2099 fn init_test(cx: &mut TestAppContext) {
2100 env_logger::try_init().ok();
2101 cx.update(|cx| {
2102 let settings_store = SettingsStore::test(cx);
2103 cx.set_global(settings_store);
2104 Project::init_settings(cx);
2105 language::init(cx);
2106 });
2107 }
2108
2109 #[gpui::test]
2110 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2111 init_test(cx);
2112
2113 let fs = FakeFs::new(cx.executor());
2114 let project = Project::test(fs, [], cx).await;
2115 let connection = Rc::new(FakeAgentConnection::new());
2116 let thread = cx
2117 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2118 .await
2119 .unwrap();
2120
2121 // Test creating a new user message
2122 thread.update(cx, |thread, cx| {
2123 thread.push_user_content_block(
2124 None,
2125 acp::ContentBlock::Text(acp::TextContent {
2126 annotations: None,
2127 text: "Hello, ".to_string(),
2128 }),
2129 cx,
2130 );
2131 });
2132
2133 thread.update(cx, |thread, cx| {
2134 assert_eq!(thread.entries.len(), 1);
2135 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2136 assert_eq!(user_msg.id, None);
2137 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2138 } else {
2139 panic!("Expected UserMessage");
2140 }
2141 });
2142
2143 // Test appending to existing user message
2144 let message_1_id = UserMessageId::new();
2145 thread.update(cx, |thread, cx| {
2146 thread.push_user_content_block(
2147 Some(message_1_id.clone()),
2148 acp::ContentBlock::Text(acp::TextContent {
2149 annotations: None,
2150 text: "world!".to_string(),
2151 }),
2152 cx,
2153 );
2154 });
2155
2156 thread.update(cx, |thread, cx| {
2157 assert_eq!(thread.entries.len(), 1);
2158 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2159 assert_eq!(user_msg.id, Some(message_1_id));
2160 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2161 } else {
2162 panic!("Expected UserMessage");
2163 }
2164 });
2165
2166 // Test creating new user message after assistant message
2167 thread.update(cx, |thread, cx| {
2168 thread.push_assistant_content_block(
2169 acp::ContentBlock::Text(acp::TextContent {
2170 annotations: None,
2171 text: "Assistant response".to_string(),
2172 }),
2173 false,
2174 cx,
2175 );
2176 });
2177
2178 let message_2_id = UserMessageId::new();
2179 thread.update(cx, |thread, cx| {
2180 thread.push_user_content_block(
2181 Some(message_2_id.clone()),
2182 acp::ContentBlock::Text(acp::TextContent {
2183 annotations: None,
2184 text: "New user message".to_string(),
2185 }),
2186 cx,
2187 );
2188 });
2189
2190 thread.update(cx, |thread, cx| {
2191 assert_eq!(thread.entries.len(), 3);
2192 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2193 assert_eq!(user_msg.id, Some(message_2_id));
2194 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2195 } else {
2196 panic!("Expected UserMessage at index 2");
2197 }
2198 });
2199 }
2200
2201 #[gpui::test]
2202 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2203 init_test(cx);
2204
2205 let fs = FakeFs::new(cx.executor());
2206 let project = Project::test(fs, [], cx).await;
2207 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2208 |_, thread, mut cx| {
2209 async move {
2210 thread.update(&mut cx, |thread, cx| {
2211 thread
2212 .handle_session_update(
2213 acp::SessionUpdate::AgentThoughtChunk {
2214 content: "Thinking ".into(),
2215 },
2216 cx,
2217 )
2218 .unwrap();
2219 thread
2220 .handle_session_update(
2221 acp::SessionUpdate::AgentThoughtChunk {
2222 content: "hard!".into(),
2223 },
2224 cx,
2225 )
2226 .unwrap();
2227 })?;
2228 Ok(acp::PromptResponse {
2229 stop_reason: acp::StopReason::EndTurn,
2230 })
2231 }
2232 .boxed_local()
2233 },
2234 ));
2235
2236 let thread = cx
2237 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2238 .await
2239 .unwrap();
2240
2241 thread
2242 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2243 .await
2244 .unwrap();
2245
2246 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2247 assert_eq!(
2248 output,
2249 indoc! {r#"
2250 ## User
2251
2252 Hello from Zed!
2253
2254 ## Assistant
2255
2256 <thinking>
2257 Thinking hard!
2258 </thinking>
2259
2260 "#}
2261 );
2262 }
2263
2264 #[gpui::test]
2265 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2266 init_test(cx);
2267
2268 let fs = FakeFs::new(cx.executor());
2269 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2270 .await;
2271 let project = Project::test(fs.clone(), [], cx).await;
2272 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2273 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2274 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2275 move |_, thread, mut cx| {
2276 let read_file_tx = read_file_tx.clone();
2277 async move {
2278 let content = thread
2279 .update(&mut cx, |thread, cx| {
2280 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2281 })
2282 .unwrap()
2283 .await
2284 .unwrap();
2285 assert_eq!(content, "one\ntwo\nthree\n");
2286 read_file_tx.take().unwrap().send(()).unwrap();
2287 thread
2288 .update(&mut cx, |thread, cx| {
2289 thread.write_text_file(
2290 path!("/tmp/foo").into(),
2291 "one\ntwo\nthree\nfour\nfive\n".to_string(),
2292 cx,
2293 )
2294 })
2295 .unwrap()
2296 .await
2297 .unwrap();
2298 Ok(acp::PromptResponse {
2299 stop_reason: acp::StopReason::EndTurn,
2300 })
2301 }
2302 .boxed_local()
2303 },
2304 ));
2305
2306 let (worktree, pathbuf) = project
2307 .update(cx, |project, cx| {
2308 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2309 })
2310 .await
2311 .unwrap();
2312 let buffer = project
2313 .update(cx, |project, cx| {
2314 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2315 })
2316 .await
2317 .unwrap();
2318
2319 let thread = cx
2320 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2321 .await
2322 .unwrap();
2323
2324 let request = thread.update(cx, |thread, cx| {
2325 thread.send_raw("Extend the count in /tmp/foo", cx)
2326 });
2327 read_file_rx.await.ok();
2328 buffer.update(cx, |buffer, cx| {
2329 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2330 });
2331 cx.run_until_parked();
2332 assert_eq!(
2333 buffer.read_with(cx, |buffer, _| buffer.text()),
2334 "zero\none\ntwo\nthree\nfour\nfive\n"
2335 );
2336 assert_eq!(
2337 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2338 "zero\none\ntwo\nthree\nfour\nfive\n"
2339 );
2340 request.await.unwrap();
2341 }
2342
    /// A tool call that was marked `Canceled` by cancelling the thread can still
    /// transition to `Completed` when the agent later reports it as finished.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // Shared between the fake agent and the test body so that the later
        // `ToolCallUpdate` targets the same tool call.
        let id = acp::ToolCallId("test".into());

        // The fake agent reports a single in-progress `Fetch` tool call and then
        // ends its turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // The tool call sits at index 1 (presumably index 0 is the user message)
        // and is still in progress.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        // Cancelling the thread marks the in-flight tool call as canceled.
        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent now reports the already-canceled tool call as completed.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        // The late completion is applied despite the earlier cancellation.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2442
2443 #[gpui::test]
2444 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2445 init_test(cx);
2446 let fs = FakeFs::new(cx.background_executor.clone());
2447 fs.insert_tree(path!("/test"), json!({})).await;
2448 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2449
2450 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2451 move |_, thread, mut cx| {
2452 async move {
2453 thread
2454 .update(&mut cx, |thread, cx| {
2455 thread.handle_session_update(
2456 acp::SessionUpdate::ToolCall(acp::ToolCall {
2457 id: acp::ToolCallId("test".into()),
2458 title: "Label".into(),
2459 kind: acp::ToolKind::Edit,
2460 status: acp::ToolCallStatus::Completed,
2461 content: vec![acp::ToolCallContent::Diff {
2462 diff: acp::Diff {
2463 path: "/test/test.txt".into(),
2464 old_text: None,
2465 new_text: "foo".into(),
2466 },
2467 }],
2468 locations: vec![],
2469 raw_input: None,
2470 raw_output: None,
2471 }),
2472 cx,
2473 )
2474 })
2475 .unwrap()
2476 .unwrap();
2477 Ok(acp::PromptResponse {
2478 stop_reason: acp::StopReason::EndTurn,
2479 })
2480 }
2481 .boxed_local()
2482 }
2483 }));
2484
2485 let thread = cx
2486 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2487 .await
2488 .unwrap();
2489
2490 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2491 .await
2492 .unwrap();
2493
2494 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2495 }
2496
    /// Exercises git checkpoints attached to user messages.
    ///
    /// The fake agent echoes each prompt back in upper case. While
    /// `simulate_changes` is set, it first writes a new empty file
    /// (`/test/file-N`) so that the turn changes the worktree. The test checks
    /// that:
    /// - messages whose turn changed files are rendered as `## User (checkpoint)`;
    /// - a turn without file changes keeps a plain `## User` header;
    /// - rewinding to a checkpointed message truncates the conversation and
    ///   restores the files to that checkpoint.
    ///
    /// NOTE(review): `iterations = 10` presumably reruns the test with different
    /// scheduler seeds to shake out ordering issues — confirm.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // An empty `.git` directory in the root (presumably so the project has a
        // git repository to take checkpoints from).
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // Toggles whether the fake agent touches the filesystem during a turn.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        // Suffix for the next file the fake agent creates (`file-0`, `file-1`, ...).
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Reply with the prompt text in upper case.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First turn creates `file-0`, so the message is shown with a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Second turn creates `file-1`, so it is checkpointed as well.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // `entries[2]` is the "ipsum" message (0: "Lorem", 1: its reply), so the
        // rewind drops it and everything after, and removes `file-1`.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
2676
2677 #[gpui::test]
2678 async fn test_refusal(cx: &mut TestAppContext) {
2679 init_test(cx);
2680 let fs = FakeFs::new(cx.background_executor.clone());
2681 fs.insert_tree(path!("/"), json!({})).await;
2682 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2683
2684 let refuse_next = Arc::new(AtomicBool::new(false));
2685 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2686 let refuse_next = refuse_next.clone();
2687 move |request, thread, mut cx| {
2688 let refuse_next = refuse_next.clone();
2689 async move {
2690 if refuse_next.load(SeqCst) {
2691 return Ok(acp::PromptResponse {
2692 stop_reason: acp::StopReason::Refusal,
2693 });
2694 }
2695
2696 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2697 panic!("expected text content block");
2698 };
2699 thread.update(&mut cx, |thread, cx| {
2700 thread
2701 .handle_session_update(
2702 acp::SessionUpdate::AgentMessageChunk {
2703 content: content.text.to_uppercase().into(),
2704 },
2705 cx,
2706 )
2707 .unwrap();
2708 })?;
2709 Ok(acp::PromptResponse {
2710 stop_reason: acp::StopReason::EndTurn,
2711 })
2712 }
2713 .boxed_local()
2714 }
2715 }));
2716 let thread = cx
2717 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2718 .await
2719 .unwrap();
2720
2721 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2722 .await
2723 .unwrap();
2724 thread.read_with(cx, |thread, cx| {
2725 assert_eq!(
2726 thread.to_markdown(cx),
2727 indoc! {"
2728 ## User
2729
2730 hello
2731
2732 ## Assistant
2733
2734 HELLO
2735
2736 "}
2737 );
2738 });
2739
2740 // Simulate refusing the second message, ensuring the conversation gets
2741 // truncated to before sending it.
2742 refuse_next.store(true, SeqCst);
2743 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2744 .await
2745 .unwrap();
2746 thread.read_with(cx, |thread, cx| {
2747 assert_eq!(
2748 thread.to_markdown(cx),
2749 indoc! {"
2750 ## User
2751
2752 hello
2753
2754 ## Assistant
2755
2756 HELLO
2757
2758 "}
2759 );
2760 });
2761 }
2762
2763 async fn run_until_first_tool_call(
2764 thread: &Entity<AcpThread>,
2765 cx: &mut TestAppContext,
2766 ) -> usize {
2767 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2768
2769 let subscription = cx.update(|cx| {
2770 cx.subscribe(thread, move |thread, _, cx| {
2771 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2772 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2773 return tx.try_send(ix).unwrap();
2774 }
2775 }
2776 })
2777 });
2778
2779 select! {
2780 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2781 panic!("Timeout waiting for tool call")
2782 }
2783 ix = rx.next().fuse() => {
2784 drop(subscription);
2785 ix.unwrap()
2786 }
2787 }
2788 }
2789
2790 #[derive(Clone, Default)]
2791 struct FakeAgentConnection {
2792 auth_methods: Vec<acp::AuthMethod>,
2793 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2794 on_user_message: Option<
2795 Rc<
2796 dyn Fn(
2797 acp::PromptRequest,
2798 WeakEntity<AcpThread>,
2799 AsyncApp,
2800 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2801 + 'static,
2802 >,
2803 >,
2804 }
2805
2806 impl FakeAgentConnection {
2807 fn new() -> Self {
2808 Self {
2809 auth_methods: Vec::new(),
2810 on_user_message: None,
2811 sessions: Arc::default(),
2812 }
2813 }
2814
2815 #[expect(unused)]
2816 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2817 self.auth_methods = auth_methods;
2818 self
2819 }
2820
2821 fn on_user_message(
2822 mut self,
2823 handler: impl Fn(
2824 acp::PromptRequest,
2825 WeakEntity<AcpThread>,
2826 AsyncApp,
2827 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2828 + 'static,
2829 ) -> Self {
2830 self.on_user_message.replace(Rc::new(handler));
2831 self
2832 }
2833 }
2834
2835 impl AgentConnection for FakeAgentConnection {
2836 fn auth_methods(&self) -> &[acp::AuthMethod] {
2837 &self.auth_methods
2838 }
2839
2840 fn new_thread(
2841 self: Rc<Self>,
2842 project: Entity<Project>,
2843 _cwd: &Path,
2844 cx: &mut App,
2845 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2846 let session_id = acp::SessionId(
2847 rand::thread_rng()
2848 .sample_iter(&rand::distributions::Alphanumeric)
2849 .take(7)
2850 .map(char::from)
2851 .collect::<String>()
2852 .into(),
2853 );
2854 let action_log = cx.new(|_| ActionLog::new(project.clone()));
2855 let thread = cx.new(|cx| {
2856 AcpThread::new(
2857 "Test",
2858 self.clone(),
2859 project,
2860 action_log,
2861 session_id.clone(),
2862 watch::Receiver::constant(acp::PromptCapabilities {
2863 image: true,
2864 audio: true,
2865 embedded_context: true,
2866 }),
2867 cx,
2868 )
2869 });
2870 self.sessions.lock().insert(session_id, thread.downgrade());
2871 Task::ready(Ok(thread))
2872 }
2873
2874 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2875 if self.auth_methods().iter().any(|m| m.id == method) {
2876 Task::ready(Ok(()))
2877 } else {
2878 Task::ready(Err(anyhow!("Invalid Auth Method")))
2879 }
2880 }
2881
2882 fn prompt(
2883 &self,
2884 _id: Option<UserMessageId>,
2885 params: acp::PromptRequest,
2886 cx: &mut App,
2887 ) -> Task<gpui::Result<acp::PromptResponse>> {
2888 let sessions = self.sessions.lock();
2889 let thread = sessions.get(¶ms.session_id).unwrap();
2890 if let Some(handler) = &self.on_user_message {
2891 let handler = handler.clone();
2892 let thread = thread.clone();
2893 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2894 } else {
2895 Task::ready(Ok(acp::PromptResponse {
2896 stop_reason: acp::StopReason::EndTurn,
2897 }))
2898 }
2899 }
2900
2901 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2902 let sessions = self.sessions.lock();
2903 let thread = sessions.get(session_id).unwrap().clone();
2904
2905 cx.spawn(async move |cx| {
2906 thread
2907 .update(cx, |thread, cx| thread.cancel(cx))
2908 .unwrap()
2909 .await
2910 })
2911 .detach();
2912 }
2913
2914 fn truncate(
2915 &self,
2916 session_id: &acp::SessionId,
2917 _cx: &App,
2918 ) -> Option<Rc<dyn AgentSessionTruncate>> {
2919 Some(Rc::new(FakeAgentSessionEditor {
2920 _session_id: session_id.clone(),
2921 }))
2922 }
2923
2924 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2925 self
2926 }
2927 }
2928
2929 struct FakeAgentSessionEditor {
2930 _session_id: acp::SessionId,
2931 }
2932
2933 impl AgentSessionTruncate for FakeAgentSessionEditor {
2934 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2935 Task::ready(Ok(()))
2936 }
2937 }
2938}