1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use util::{ResultExt as _, TryFutureExt as _, post_inc};
39use uuid::Uuid;
40use zed_llm_client::{CompletionMode, CompletionRequestStatus};
41
42use crate::ThreadStore;
43use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
44use crate::thread_store::{
45 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
46 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
47};
48use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
49
50#[derive(
51 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
52)]
53pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
73/// The ID of the user prompt that initiated a request.
74///
75/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
76#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
77pub struct PromptId(Arc<str>);
78
79impl PromptId {
80 pub fn new() -> Self {
81 Self(Uuid::new_v4().to_string().into())
82 }
83}
84
85impl std::fmt::Display for PromptId {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 write!(f, "{}", self.0)
88 }
89}
90
91#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
92pub struct MessageId(pub(crate) usize);
93
94impl MessageId {
95 fn post_inc(&mut self) -> Self {
96 Self(post_inc(&mut self.0))
97 }
98}
99
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range of the crease within the message text.
    pub range: Range<usize>,
    /// Icon and label used to render the crease placeholder.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
108
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered segments (text / thinking / redacted thinking) making up the body.
    pub segments: Vec<MessageSegment>,
    /// Context that was attached to this message when it was sent.
    pub loaded_context: LoadedContext,
    /// Creases used to re-render attached context in editors for this message.
    pub creases: Vec<MessageCrease>,
}
118
119impl Message {
120 /// Returns whether the message contains any meaningful text that should be displayed
121 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
122 pub fn should_display_content(&self) -> bool {
123 self.segments.iter().all(|segment| segment.should_display())
124 }
125
126 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
127 if let Some(MessageSegment::Thinking {
128 text: segment,
129 signature: current_signature,
130 }) = self.segments.last_mut()
131 {
132 if let Some(signature) = signature {
133 *current_signature = Some(signature);
134 }
135 segment.push_str(text);
136 } else {
137 self.segments.push(MessageSegment::Thinking {
138 text: text.to_string(),
139 signature,
140 });
141 }
142 }
143
144 pub fn push_text(&mut self, text: &str) {
145 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
146 segment.push_str(text);
147 } else {
148 self.segments.push(MessageSegment::Text(text.to_string()));
149 }
150 }
151
152 pub fn to_string(&self) -> String {
153 let mut result = String::new();
154
155 if !self.loaded_context.text.is_empty() {
156 result.push_str(&self.loaded_context.text);
157 }
158
159 for segment in &self.segments {
160 match segment {
161 MessageSegment::Text(text) => result.push_str(text),
162 MessageSegment::Thinking { text, .. } => {
163 result.push_str("<think>\n");
164 result.push_str(text);
165 result.push_str("\n</think>");
166 }
167 MessageSegment::RedactedThinking(_) => {}
168 }
169 }
170
171 result
172 }
173}
174
175#[derive(Debug, Clone, PartialEq, Eq)]
176pub enum MessageSegment {
177 Text(String),
178 Thinking {
179 text: String,
180 signature: Option<String>,
181 },
182 RedactedThinking(Vec<u8>),
183}
184
185impl MessageSegment {
186 pub fn should_display(&self) -> bool {
187 match self {
188 Self::Text(text) => text.is_empty(),
189 Self::Thinking { text, .. } => text.is_empty(),
190 Self::RedactedThinking(_) => false,
191 }
192 }
193}
194
/// A point-in-time capture of the state of the project's worktrees.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was captured.
    pub timestamp: DateTime<Utc>,
}
201
/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when no git state could be captured for this worktree.
    pub git_state: Option<GitState>,
}
207
/// Git metadata captured as part of a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Diff captured at snapshot time, if one was collected.
    pub diff: Option<String>,
}
215
/// Associates a git checkpoint with the message it was taken for, so the
/// project can later be rolled back to the state at that message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
221
/// User feedback (thumbs up / thumbs down) recorded for a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
227
228pub enum LastRestoreCheckpoint {
229 Pending {
230 message_id: MessageId,
231 },
232 Error {
233 message_id: MessageId,
234 error: String,
235 },
236}
237
238impl LastRestoreCheckpoint {
239 pub fn message_id(&self) -> MessageId {
240 match self {
241 LastRestoreCheckpoint::Pending { message_id } => *message_id,
242 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
243 }
244 }
245}
246
247#[derive(Clone, Debug, Default, Serialize, Deserialize)]
248pub enum DetailedSummaryState {
249 #[default]
250 NotGenerated,
251 Generating {
252 message_id: MessageId,
253 },
254 Generated {
255 text: SharedString,
256 message_id: MessageId,
257 },
258}
259
260impl DetailedSummaryState {
261 fn text(&self) -> Option<SharedString> {
262 if let Self::Generated { text, .. } = self {
263 Some(text.clone())
264 } else {
265 None
266 }
267 }
268}
269
/// Running token total for a thread, together with the model's context limit.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens allowed; zero when the limit is unknown (no model selected).
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, nearing the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via
        // ZED_THREAD_WARNING_THRESHOLD. Fall back to the default instead of
        // panicking when the variable is unset or unparsable (the previous
        // `.unwrap()` would crash on a malformed value).
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` additional tokens counted.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close a thread's token usage is to the model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
314
315fn default_completion_mode(cx: &App) -> CompletionMode {
316 if cx.is_staff() {
317 CompletionMode::Max
318 } else {
319 CompletionMode::Normal
320 }
321}
322
/// Where a pending completion currently sits in the request queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent; no queue information yet.
    Sending,
    /// The request is queued behind `position` other requests.
    Queued { position: usize },
    /// The completion has started streaming.
    Started,
}
329
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// `None` until a summary has been generated for the thread.
    summary: Option<SharedString>,
    /// In-flight task generating the short summary, if any.
    pending_summary: Task<Option<()>>,
    /// In-flight task generating the detailed summary, if any.
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: CompletionMode,
    messages: Vec<Message>,
    // Monotonic source of message ids; keeps `messages` sorted by id.
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the message they were taken for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Most recent checkpoint-restore attempt, kept for error reporting.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken for the latest user message, finalized when generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Snapshot of the project captured when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    /// Token usage summed across all requests in this thread.
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    /// Feedback for the thread as a whole.
    feedback: Option<ThreadFeedback>,
    /// Feedback recorded per message.
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// When the last streaming chunk arrived; used for staleness detection.
    last_received_chunk_at: Option<Instant>,
    /// Optional hook invoked with each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Turns still allowed; `u32::MAX` means effectively unlimited.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
370
/// Records that a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
378
379impl Thread {
    /// Creates an empty thread with a fresh id, the registry's default model,
    /// and a project snapshot that is captured asynchronously at creation time.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project snapshot in the background; `Shared` lets
            // multiple consumers await the same result.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
433
    /// Reconstructs a thread from its serialized form.
    ///
    /// Creases come back without live context handles (`context: None`), and
    /// the serialized model is re-resolved against the registry, falling back
    /// to the default model when it is unavailable.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue id allocation after the highest serialized message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
542
    /// Installs a hook that observes each completion request along with the
    /// events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
550
    /// The stable unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Generates a fresh prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, if one exists yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared project context used when building the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, initializing it from the registry's
    /// default model when not yet set.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets the model used for subsequent completions.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Summary shown before a real one has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The thread summary, or [`Self::DEFAULT_SUMMARY`] when none exists.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
600
601 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
602 let Some(current_summary) = &self.summary else {
603 // Don't allow setting summary until generated
604 return;
605 };
606
607 let mut new_summary = new_summary.into();
608
609 if new_summary.is_empty() {
610 new_summary = Self::DEFAULT_SUMMARY;
611 }
612
613 if current_summary != &new_summary {
614 self.summary = Some(new_summary);
615 cx.emit(ThreadEvent::SummaryChanged);
616 }
617 }
618
    /// The completion mode used for requests.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the completion mode used for requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    /// Looks up a message by id.
    ///
    /// Uses binary search: ids are allocated monotonically, so `messages`
    /// stays sorted by id.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
639
    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        // NOTE(review): staleness is derived solely from the last chunk
        // timestamp; the doc comment above assumes `last_received_chunk_at`
        // is cleared when generation ends — confirm elsewhere in this file.
        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
662
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for a message, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
688
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message. Progress and
    /// failures are surfaced via [`LastRestoreCheckpoint`] and
    /// `CheckpointChanged` events.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so the UI reflects it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failed attempt so the error can be displayed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // On success, drop the checkpointed message and everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
724
    /// Once generation has finished, decides whether to keep the checkpoint
    /// taken for the last user message: it is kept only when the repository
    /// state actually changed during the turn.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still generating; finalize after generation completes.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // Compare the pre-turn checkpoint against the current state;
                // treat a failed comparison as "changed" to be safe.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint: keep the pending one anyway.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
763
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The most recent checkpoint-restore attempt, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
774
775 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
776 let Some(message_ix) = self
777 .messages
778 .iter()
779 .rposition(|message| message.id == message_id)
780 else {
781 return;
782 };
783 for deleted_message in self.messages.drain(message_ix..) {
784 self.checkpoints_by_message.remove(&deleted_message.id);
785 }
786 cx.notify();
787 }
788
789 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
790 self.messages
791 .iter()
792 .find(|message| message.id == id)
793 .into_iter()
794 .flat_map(|message| message.loaded_context.contexts.iter())
795 }
796
797 pub fn is_turn_end(&self, ix: usize) -> bool {
798 if self.messages.is_empty() {
799 return false;
800 }
801
802 if !self.is_generating() && ix == self.messages.len() - 1 {
803 return true;
804 }
805
806 let Some(message) = self.messages.get(ix) else {
807 return false;
808 };
809
810 if message.role != Role::Assistant {
811 return false;
812 }
813
814 self.messages
815 .get(ix + 1)
816 .and_then(|message| {
817 self.message(message.id)
818 .map(|next_message| next_message.role == Role::User)
819 })
820 .unwrap_or(false)
821 }
822
    /// Usage information from the most recent completion, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the tool-use limit has been reached for this thread.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses associated with the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
844
    /// Tool results associated with the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result recorded for a tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a completed tool use, if any.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card produced for a tool use's result, if one exists.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
863
864 /// Return tools that are both enabled and supported by the model
865 pub fn available_tools(
866 &self,
867 cx: &App,
868 model: Arc<dyn LanguageModel>,
869 ) -> Vec<LanguageModelRequestTool> {
870 if model.supports_tools() {
871 self.tools()
872 .read(cx)
873 .enabled_tools(cx)
874 .into_iter()
875 .filter_map(|tool| {
876 // Skip tools that cannot be supported
877 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
878 Some(LanguageModelRequestTool {
879 name: tool.name(),
880 description: tool.description(),
881 input_schema,
882 })
883 })
884 .collect()
885 } else {
886 Vec::default()
887 }
888 }
889
    /// Appends a user message with its loaded context, and records
    /// `git_checkpoint` as the pending checkpoint for this turn.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Record every buffer the attached context referenced as "read" in
        // the action log.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        // The checkpoint stays pending until generation ends, when we can
        // tell whether the repository actually changed.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
925
    /// Appends an assistant message with the given segments and no attached context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
939
    /// Appends a message with a freshly allocated id, bumps `updated_at`, and
    /// emits [`ThreadEvent::MessageAdded`]. Returns the new message's id.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
960
961 pub fn edit_message(
962 &mut self,
963 id: MessageId,
964 new_role: Role,
965 new_segments: Vec<MessageSegment>,
966 loaded_context: Option<LoadedContext>,
967 cx: &mut Context<Self>,
968 ) -> bool {
969 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
970 return false;
971 };
972 message.role = new_role;
973 message.segments = new_segments;
974 if let Some(context) = loaded_context {
975 message.loaded_context = context;
976 }
977 self.touch_updated_at();
978 cx.emit(ThreadEvent::MessageEdited(id));
979 true
980 }
981
982 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
983 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
984 return false;
985 };
986 self.messages.remove(index);
987 self.touch_updated_at();
988 cx.emit(ThreadEvent::MessageDeleted(id));
989 true
990 }
991
992 /// Returns the representation of this [`Thread`] in a textual form.
993 ///
994 /// This is the representation we use when attaching a thread as context to another thread.
995 pub fn text(&self) -> String {
996 let mut text = String::new();
997
998 for message in &self.messages {
999 text.push_str(match message.role {
1000 language_model::Role::User => "User:",
1001 language_model::Role::Assistant => "Assistant:",
1002 language_model::Role::System => "System:",
1003 });
1004 text.push('\n');
1005
1006 for segment in &message.segments {
1007 match segment {
1008 MessageSegment::Text(content) => text.push_str(content),
1009 MessageSegment::Thinking { text: content, .. } => {
1010 text.push_str(&format!("<think>{}</think>", content))
1011 }
1012 MessageSegment::RedactedThinking(_) => {}
1013 }
1014 }
1015 text.push('\n');
1016 }
1017
1018 text
1019 }
1020
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Wait for the initial snapshot first; the rest is read synchronously.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
            })
        })
    }
1101
    /// Number of agentic turns the thread may still take.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    /// Sets the number of agentic turns the thread may still take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1109
    /// Sends the current thread state to `model` as a completion request,
    /// consuming one remaining turn. Does nothing when no turns remain.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        // Safe: guarded by the zero check above.
        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), cx);

        self.stream_completion(request, model, window, cx);
    }
1126
1127 pub fn used_tools_since_last_user_message(&self) -> bool {
1128 for message in self.messages.iter().rev() {
1129 if self.tool_use.message_has_tool_results(message.id) {
1130 return true;
1131 } else if message.role == Role::User {
1132 return false;
1133 }
1134 }
1135
1136 false
1137 }
1138
    /// Assembles the full `LanguageModelRequest` for this thread: system
    /// prompt, message history (with loaded context, tool uses, and tool
    /// results), a stale-files notice, the available tools, and the
    /// completion mode.
    ///
    /// Emits `ThreadEvent::ShowError` (and still returns a request without a
    /// system message) if the system prompt cannot be generated.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // Tool names are fed to the system prompt so it can mention what the
        // model may call.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt is stable across turns, so mark it
                    // cacheable for providers that support prompt caching.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            // The project context is loaded asynchronously; hitting this
            // branch means a request was built before it resolved.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) precedes the message's
            // own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results are sent as a separate message immediately after
            // the message that requested them.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode)
        } else {
            Some(CompletionMode::Normal)
        };

        request
    }
1256
1257 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1258 let mut request = LanguageModelRequest {
1259 thread_id: None,
1260 prompt_id: None,
1261 mode: None,
1262 messages: vec![],
1263 tools: Vec::new(),
1264 stop: Vec::new(),
1265 temperature: None,
1266 };
1267
1268 for message in &self.messages {
1269 let mut request_message = LanguageModelRequestMessage {
1270 role: message.role,
1271 content: Vec::new(),
1272 cache: false,
1273 };
1274
1275 for segment in &message.segments {
1276 match segment {
1277 MessageSegment::Text(text) => request_message
1278 .content
1279 .push(MessageContent::Text(text.clone())),
1280 MessageSegment::Thinking { .. } => {}
1281 MessageSegment::RedactedThinking(_) => {}
1282 }
1283 }
1284
1285 if request_message.content.is_empty() {
1286 continue;
1287 }
1288
1289 request.messages.push(request_message);
1290 }
1291
1292 request.messages.push(LanguageModelRequestMessage {
1293 role: Role::User,
1294 content: vec![MessageContent::Text(added_user_message)],
1295 cache: false,
1296 });
1297
1298 request
1299 }
1300
1301 fn attached_tracked_files_state(
1302 &self,
1303 messages: &mut Vec<LanguageModelRequestMessage>,
1304 cx: &App,
1305 ) {
1306 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1307
1308 let mut stale_message = String::new();
1309
1310 let action_log = self.action_log.read(cx);
1311
1312 for stale_file in action_log.stale_buffers(cx) {
1313 let Some(file) = stale_file.read(cx).file() else {
1314 continue;
1315 };
1316
1317 if stale_message.is_empty() {
1318 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1319 }
1320
1321 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1322 }
1323
1324 let mut content = Vec::with_capacity(2);
1325
1326 if !stale_message.is_empty() {
1327 content.push(stale_message.into());
1328 }
1329
1330 if !content.is_empty() {
1331 let context_message = LanguageModelRequestMessage {
1332 role: Role::User,
1333 content,
1334 cache: false,
1335 };
1336
1337 messages.push(context_message);
1338 }
1339 }
1340
    /// Streams a completion for `request` from `model`, applying each event to
    /// the thread as it arrives (text/thinking chunks, tool uses, usage and
    /// queue-status updates), then handles the stop reason or error when the
    /// stream ends. The spawned task is tracked in `pending_completions`.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record request/response pairs when a debug callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before streaming so the per-request delta can be
            // reported to telemetry afterwards.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // Id of the assistant message created for this response, if any.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            // Bad tool-input JSON is recoverable: record the
                            // failure as a tool error and keep streaming.
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // only the delta since the last update is added.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Partial inputs are forwarded so the UI can
                                // render tool input as it streams in.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message
                                        } => {
                                            return Err(anyhow!("completion request failed. code: {code}, message: {message}"));
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so the UI stays responsive during
                    // fast streams.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to specific UI errors;
                            // anything else is shown as a generic message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1687
    /// Asks the summary model for a short thread title, streaming the result
    /// into `self.summary` and emitting `SummaryGenerated` when done.
    ///
    /// Silently does nothing when no summary model is configured or its
    /// provider is not authenticated.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        // Assigning a new task drops (cancels) any in-flight summarization.
        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        // Usage updates still arrive on this side-channel
                        // request; record them like any other completion.
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            this.update(cx, |thread, _cx| {
                                thread.last_usage = Some(RequestUsage {
                                    limit,
                                    amount: amount as i32,
                                });
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    // A title is a single line; keep only the first line of
                    // each chunk.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    // Emitted even when empty so listeners can stop waiting.
                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1750
    /// Starts generating a detailed Markdown summary of the thread unless one
    /// is already generated (or generating) for the current last message.
    /// Progress is broadcast through the `detailed_summary_tx` watch channel,
    /// and the finished thread is persisted via `thread_store`.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed to start; roll the state back so a later
                // call can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1840
    /// Waits on the detailed-summary watch channel until generation settles:
    /// returns the generated summary, falls back to the thread's plain text
    /// when generation is `NotGenerated`, and returns `None` if the channel
    /// closes or the entity is dropped.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still in flight; keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1858
1859 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1860 self.detailed_summary_rx
1861 .borrow()
1862 .text()
1863 .unwrap_or_else(|| self.text().into())
1864 }
1865
1866 pub fn is_generating_detailed_summary(&self) -> bool {
1867 matches!(
1868 &*self.detailed_summary_rx.borrow(),
1869 DetailedSummaryState::Generating { .. }
1870 )
1871 }
1872
    /// Dispatches every idle pending tool use: tools that need confirmation
    /// (and aren't globally auto-allowed) are queued for user approval, the
    /// rest are run immediately. Returns the tool uses that were considered.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Tools receive the conversation as it would be sent to the model.
        let request = self.to_completion_request(model, cx);
        let messages = Arc::new(request.messages);
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            // Tool uses whose tool is unknown are silently skipped here.
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        window,
                        cx,
                    );
                }
            }
        }

        pending_tool_uses
    }
1919
    /// Records a tool use whose input JSON failed to parse: stores a tool
    /// error result, emits `InvalidToolInput` for the UI, and completes the
    /// tool via `tool_finished` so the turn can continue.
    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            self.configured_model.as_ref(),
        );
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            // Shouldn't happen, but fall back to a generic label rather than
            // panicking.
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
1954
1955 pub fn run_tool(
1956 &mut self,
1957 tool_use_id: LanguageModelToolUseId,
1958 ui_text: impl Into<SharedString>,
1959 input: serde_json::Value,
1960 messages: &[LanguageModelRequestMessage],
1961 tool: Arc<dyn Tool>,
1962 window: Option<AnyWindowHandle>,
1963 cx: &mut Context<Thread>,
1964 ) {
1965 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1966 self.tool_use
1967 .run_pending_tool(tool_use_id, ui_text.into(), task);
1968 }
1969
    /// Runs `tool` (or short-circuits with an error if the tool is disabled),
    /// stores any UI card it produces, and returns a task that records the
    /// tool's output on completion via `tool_finished`.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            // Disabled tools still produce a result so the model sees the
            // failure instead of a hung tool call.
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2018
2019 fn tool_finished(
2020 &mut self,
2021 tool_use_id: LanguageModelToolUseId,
2022 pending_tool_use: Option<PendingToolUse>,
2023 canceled: bool,
2024 window: Option<AnyWindowHandle>,
2025 cx: &mut Context<Self>,
2026 ) {
2027 if self.all_tools_finished() {
2028 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2029 if !canceled {
2030 self.send_to_model(model.clone(), window, cx);
2031 }
2032 self.auto_capture_telemetry(cx);
2033 }
2034 }
2035
2036 cx.emit(ThreadEvent::ToolFinished {
2037 tool_use_id,
2038 pending_tool_use,
2039 });
2040 }
2041
2042 /// Cancels the last pending completion, if there are any pending.
2043 ///
2044 /// Returns whether a completion was canceled.
2045 pub fn cancel_last_completion(
2046 &mut self,
2047 window: Option<AnyWindowHandle>,
2048 cx: &mut Context<Self>,
2049 ) -> bool {
2050 let mut canceled = self.pending_completions.pop().is_some();
2051
2052 for pending_tool_use in self.tool_use.cancel_pending() {
2053 canceled = true;
2054 self.tool_finished(
2055 pending_tool_use.id.clone(),
2056 Some(pending_tool_use),
2057 true,
2058 window,
2059 cx,
2060 );
2061 }
2062
2063 self.finalize_pending_checkpoint(cx);
2064
2065 if canceled {
2066 cx.emit(ThreadEvent::CompletionCanceled);
2067 }
2068
2069 canceled
2070 }
2071
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It performs no cancellation
    /// itself; listeners react to the emitted `CancelEditing` event.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2079
    /// Returns the thread-level feedback rating, if one has been reported.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2083
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2087
    /// Records `feedback` for `message_id` and reports it (with a serialized
    /// thread and project snapshot) to telemetry. No-ops when the same rating
    /// was already recorded for the message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Gather everything that needs `cx` up front; the telemetry work then
        // runs on a background task.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2145
    /// Reports thread-level feedback. When an assistant message exists, the
    /// rating is attached to the latest one via `report_message_feedback`;
    /// otherwise a thread-only telemetry event is sent.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure degrades to a null payload rather
                // than dropping the event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2191
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree/git snapshots are gathered concurrently; dirty-buffer paths
    /// are collected afterwards on the foreground context.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        // Buffers without a file (untitled) are skipped.
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2229
    /// Captures a snapshot of one worktree: its absolute path plus, when a
    /// repository containing it is found locally, remote URL, HEAD SHA,
    /// current branch, and a HEAD-to-worktree diff. Each git field degrades
    /// to `None` on failure.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                // The app context is gone; return an empty placeholder.
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose root contains this worktree's path.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend for
                            // these queries; remote ones report branch only.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<Task>>; any failure along the
            // way yields no git state.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2307
2308 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2309 let mut markdown = Vec::new();
2310
2311 if let Some(summary) = self.summary() {
2312 writeln!(markdown, "# {summary}\n")?;
2313 };
2314
2315 for message in self.messages() {
2316 writeln!(
2317 markdown,
2318 "## {role}\n",
2319 role = match message.role {
2320 Role::User => "User",
2321 Role::Assistant => "Assistant",
2322 Role::System => "System",
2323 }
2324 )?;
2325
2326 if !message.loaded_context.text.is_empty() {
2327 writeln!(markdown, "{}", message.loaded_context.text)?;
2328 }
2329
2330 if !message.loaded_context.images.is_empty() {
2331 writeln!(
2332 markdown,
2333 "\n{} images attached as context.\n",
2334 message.loaded_context.images.len()
2335 )?;
2336 }
2337
2338 for segment in &message.segments {
2339 match segment {
2340 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2341 MessageSegment::Thinking { text, .. } => {
2342 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2343 }
2344 MessageSegment::RedactedThinking(_) => {}
2345 }
2346 }
2347
2348 for tool_use in self.tool_uses_for_message(message.id, cx) {
2349 writeln!(
2350 markdown,
2351 "**Use Tool: {} ({})**",
2352 tool_use.name, tool_use.id
2353 )?;
2354 writeln!(markdown, "```json")?;
2355 writeln!(
2356 markdown,
2357 "{}",
2358 serde_json::to_string_pretty(&tool_use.input)?
2359 )?;
2360 writeln!(markdown, "```")?;
2361 }
2362
2363 for tool_result in self.tool_results_for_message(message.id) {
2364 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2365 if tool_result.is_error {
2366 write!(markdown, " (Error)")?;
2367 }
2368
2369 writeln!(markdown, "**\n")?;
2370 writeln!(markdown, "{}", tool_result.content)?;
2371 }
2372 }
2373
2374 Ok(String::from_utf8_lossy(&markdown).to_string())
2375 }
2376
2377 pub fn keep_edits_in_range(
2378 &mut self,
2379 buffer: Entity<language::Buffer>,
2380 buffer_range: Range<language::Anchor>,
2381 cx: &mut Context<Self>,
2382 ) {
2383 self.action_log.update(cx, |action_log, cx| {
2384 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2385 });
2386 }
2387
2388 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2389 self.action_log
2390 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2391 }
2392
2393 pub fn reject_edits_in_ranges(
2394 &mut self,
2395 buffer: Entity<language::Buffer>,
2396 buffer_ranges: Vec<Range<language::Anchor>>,
2397 cx: &mut Context<Self>,
2398 ) -> Task<Result<()>> {
2399 self.action_log.update(cx, |action_log, cx| {
2400 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2401 })
2402 }
2403
    /// The action log that records this thread's buffer edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2407
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2411
2412 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2413 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2414 return;
2415 }
2416
2417 let now = Instant::now();
2418 if let Some(last) = self.last_auto_capture_at {
2419 if now.duration_since(last).as_secs() < 10 {
2420 return;
2421 }
2422 }
2423
2424 self.last_auto_capture_at = Some(now);
2425
2426 let thread_id = self.id().clone();
2427 let github_login = self
2428 .project
2429 .read(cx)
2430 .user_store()
2431 .read(cx)
2432 .current_user()
2433 .map(|user| user.github_login.clone());
2434 let client = self.project.read(cx).client().clone();
2435 let serialize_task = self.serialize(cx);
2436
2437 cx.background_executor()
2438 .spawn(async move {
2439 if let Ok(serialized_thread) = serialize_task.await {
2440 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2441 telemetry::event!(
2442 "Agent Thread Auto-Captured",
2443 thread_id = thread_id.to_string(),
2444 thread_data = thread_data,
2445 auto_capture_reason = "tracked_user",
2446 github_login = github_login
2447 );
2448
2449 client.telemetry().flush_events().await;
2450 }
2451 }
2452 })
2453 .detach();
2454 }
2455
    /// Token usage accumulated across all requests made by this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2459
2460 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2461 let Some(model) = self.configured_model.as_ref() else {
2462 return TotalTokenUsage::default();
2463 };
2464
2465 let max = model.model.max_token_count();
2466
2467 let index = self
2468 .messages
2469 .iter()
2470 .position(|msg| msg.id == message_id)
2471 .unwrap_or(0);
2472
2473 if index == 0 {
2474 return TotalTokenUsage { total: 0, max };
2475 }
2476
2477 let token_usage = &self
2478 .request_token_usage
2479 .get(index - 1)
2480 .cloned()
2481 .unwrap_or_default();
2482
2483 TotalTokenUsage {
2484 total: token_usage.total_tokens() as usize,
2485 max,
2486 }
2487 }
2488
2489 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2490 let model = self.configured_model.as_ref()?;
2491
2492 let max = model.model.max_token_count();
2493
2494 if let Some(exceeded_error) = &self.exceeded_window_error {
2495 if model.model.id() == exceeded_error.model_id {
2496 return Some(TotalTokenUsage {
2497 total: exceeded_error.token_count,
2498 max,
2499 });
2500 }
2501 }
2502
2503 let total = self
2504 .token_usage_at_last_message()
2505 .unwrap_or_default()
2506 .total_tokens() as usize;
2507
2508 Some(TotalTokenUsage { total, max })
2509 }
2510
2511 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2512 self.request_token_usage
2513 .get(self.messages.len().saturating_sub(1))
2514 .or_else(|| self.request_token_usage.last())
2515 .cloned()
2516 }
2517
2518 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2519 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2520 self.request_token_usage
2521 .resize(self.messages.len(), placeholder);
2522
2523 if let Some(last) = self.request_token_usage.last_mut() {
2524 *last = token_usage;
2525 }
2526 }
2527
2528 pub fn deny_tool_use(
2529 &mut self,
2530 tool_use_id: LanguageModelToolUseId,
2531 tool_name: Arc<str>,
2532 window: Option<AnyWindowHandle>,
2533 cx: &mut Context<Self>,
2534 ) {
2535 let err = Err(anyhow::anyhow!(
2536 "Permission to run tool action denied by user"
2537 ));
2538
2539 self.tool_use.insert_tool_output(
2540 tool_use_id.clone(),
2541 tool_name,
2542 err,
2543 self.configured_model.as_ref(),
2544 );
2545 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2546 }
2547}
2548
/// User-facing errors surfaced while running a thread, delivered via
/// [`ThreadEvent::ShowError`].
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider rejected the request because payment is required.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short header and a detail message to display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2563
/// Events emitted by a [`Thread`] while a completion streams in and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Surface an error to the user.
    ShowError(ThreadError),
    /// A completion event was received from the model.
    StreamedCompletion,
    /// A chunk of text arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant thinking text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request whose input parsed successfully.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that failed to parse as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted from the thread.
    MessageDeleted(MessageId),
    /// A thread summary was generated.
    SummaryGenerated,
    /// The thread summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A git checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Editing should be canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2602
// Lets `Thread` entities emit `ThreadEvent`s to gpui subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2604
/// An in-flight completion request together with the task driving it.
struct PendingCompletion {
    // Identifier for this pending completion.
    id: usize,
    // Current queueing state of the request.
    queue_state: QueueState,
    // Task driving the completion; held so it keeps running while pending.
    _task: Task<()>,
}
2610
2611#[cfg(test)]
2612mod tests {
2613 use super::*;
2614 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2615 use assistant_settings::AssistantSettings;
2616 use assistant_tool::ToolRegistry;
2617 use context_server::ContextServerSettings;
2618 use editor::EditorSettings;
2619 use gpui::TestAppContext;
2620 use language_model::fake_provider::FakeLanguageModel;
2621 use project::{FakeFs, Project};
2622 use prompt_store::PromptBuilder;
2623 use serde_json::json;
2624 use settings::{Settings, SettingsStore};
2625 use std::sync::Arc;
2626 use theme::ThemeSettings;
2627 use util::path;
2628 use workspace::Workspace;
2629
    /// Verifies that a user message with an attached file stores the rendered
    /// `<context>` block on the message and that the completion request
    /// prepends that context to the message text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request: message 0 is the system prompt, message 1
        // is the user message with the context prepended.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2705
    /// Verifies that each user message only attaches context not already
    /// included by earlier messages, and that `new_context_for_thread` can
    /// re-collect context from a given message onward while honoring removals.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages (plus the system prompt)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Re-collect from message 2 onward: contexts from messages >= 2 plus
        // the newly-added file4 are all "new" again.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2861
    /// Verifies that user messages inserted without any context have an empty
    /// context block and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2939
    /// Verifies that editing a buffer after it was attached as context causes
    /// the next completion request to end with a "files changed" notification.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3026
    /// Registers the global settings, stores, and tool registry every thread
    /// test in this module relies on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3044
    /// Helper to create a test project over an in-memory [`FakeFs`] populated
    /// from `files` (a JSON tree as accepted by `FakeFs::insert_tree`), rooted
    /// at `/test`.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }
3054
    /// Builds the common fixture for thread tests: a test workspace, a loaded
    /// [`ThreadStore`], a fresh thread, an empty [`ContextStore`], and a
    /// [`FakeLanguageModel`].
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }
3089
3090 async fn add_file_to_context(
3091 project: &Entity<Project>,
3092 context_store: &Entity<ContextStore>,
3093 path: &str,
3094 cx: &mut TestAppContext,
3095 ) -> Result<Entity<language::Buffer>> {
3096 let buffer_path = project
3097 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3098 .unwrap();
3099
3100 let buffer = project
3101 .update(cx, |project, cx| {
3102 project.open_buffer(buffer_path.clone(), cx)
3103 })
3104 .await
3105 .unwrap();
3106
3107 context_store.update(cx, |context_store, cx| {
3108 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3109 });
3110
3111 Ok(buffer)
3112 }
3113}