1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use util::{ResultExt as _, TryFutureExt as _, post_inc};
39use uuid::Uuid;
40use zed_llm_client::{CompletionMode, CompletionRequestStatus};
41
42use crate::ThreadStore;
43use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
44use crate::thread_store::{
45 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
46 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
47};
48use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
49
/// Globally unique identifier for a thread, stored as a cheaply clonable string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// Stored as a cheaply clonable string (a UUID in practice — see `PromptId::new`).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
78
79impl PromptId {
80 pub fn new() -> Self {
81 Self(Uuid::new_v4().to_string().into())
82 }
83}
84
85impl std::fmt::Display for PromptId {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 write!(f, "{}", self.0)
88 }
89}
90
91#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
92pub struct MessageId(pub(crate) usize);
93
94impl MessageId {
95 fn post_inc(&mut self) -> Self {
96 Self(post_inc(&mut self.0))
97 }
98}
99
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message text.
    pub range: Range<usize>,
    /// Icon and label used to render the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
108
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text/thinking segments making up the message body.
    pub segments: Vec<MessageSegment>,
    /// Context (files, etc.) that was attached when the message was sent.
    pub loaded_context: LoadedContext,
    /// Creases to restore when re-opening this message in an editor.
    pub creases: Vec<MessageCrease>,
}
118
119impl Message {
120 /// Returns whether the message contains any meaningful text that should be displayed
121 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
122 pub fn should_display_content(&self) -> bool {
123 self.segments.iter().all(|segment| segment.should_display())
124 }
125
126 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
127 if let Some(MessageSegment::Thinking {
128 text: segment,
129 signature: current_signature,
130 }) = self.segments.last_mut()
131 {
132 if let Some(signature) = signature {
133 *current_signature = Some(signature);
134 }
135 segment.push_str(text);
136 } else {
137 self.segments.push(MessageSegment::Thinking {
138 text: text.to_string(),
139 signature,
140 });
141 }
142 }
143
144 pub fn push_text(&mut self, text: &str) {
145 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
146 segment.push_str(text);
147 } else {
148 self.segments.push(MessageSegment::Text(text.to_string()));
149 }
150 }
151
152 pub fn to_string(&self) -> String {
153 let mut result = String::new();
154
155 if !self.loaded_context.text.is_empty() {
156 result.push_str(&self.loaded_context.text);
157 }
158
159 for segment in &self.segments {
160 match segment {
161 MessageSegment::Text(text) => result.push_str(text),
162 MessageSegment::Thinking { text, .. } => {
163 result.push_str("<think>\n");
164 result.push_str(text);
165 result.push_str("\n</think>");
166 }
167 MessageSegment::RedactedThinking(_) => {}
168 }
169 }
170
171 result
172 }
173}
174
/// One contiguous piece of a [`Message`] body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary assistant/user text.
    Text(String),
    /// Extended "thinking" output, with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted thinking; opaque bytes, never displayed.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Whether this segment carries user-visible content.
    ///
    /// Previously this returned `text.is_empty()`, which inverted the check:
    /// empty segments were reported as displayable and non-empty ones hidden,
    /// contradicting the doc on `Message::should_display_content`.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
194
/// Point-in-time capture of the project's worktrees, stored alongside a
/// serialized thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree: its path plus git state, if it is a repository.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when the worktree is not under git version control.
    pub git_state: Option<GitState>,
}

/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Textual diff of uncommitted changes, when one could be produced.
    pub diff: Option<String>,
}
215
/// Associates a message with the git checkpoint taken when it was sent, so the
/// project can later be restored to that state.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User's thumbs-up / thumbs-down rating of a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
227
228pub enum LastRestoreCheckpoint {
229 Pending {
230 message_id: MessageId,
231 },
232 Error {
233 message_id: MessageId,
234 error: String,
235 },
236}
237
238impl LastRestoreCheckpoint {
239 pub fn message_id(&self) -> MessageId {
240 match self {
241 LastRestoreCheckpoint::Pending { message_id } => *message_id,
242 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
243 }
244 }
245}
246
247#[derive(Clone, Debug, Default, Serialize, Deserialize)]
248pub enum DetailedSummaryState {
249 #[default]
250 NotGenerated,
251 Generating {
252 message_id: MessageId,
253 },
254 Generated {
255 text: SharedString,
256 message_id: MessageId,
257 },
258}
259
260impl DetailedSummaryState {
261 fn text(&self) -> Option<SharedString> {
262 if let Self::Generated { text, .. } = self {
263 Some(text.clone())
264 } else {
265 None
266 }
267 }
268}
269
/// Running token totals for a thread, relative to the model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, near the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via
        // `ZED_THREAD_WARNING_THRESHOLD`. Fall back to the default instead of
        // panicking when the variable is unset or not a valid float (the
        // previous `.parse().unwrap()` crashed on any malformed value).
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of `self` with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Coarse classification of token usage produced by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
314
315fn default_completion_mode(cx: &App) -> CompletionMode {
316 if cx.is_staff() {
317 CompletionMode::Max
318 } else {
319 CompletionMode::Normal
320 }
321}
322
/// Progress of a completion request through the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    // `position` is the spot in the provider queue reported for this request.
    Queued { position: usize },
    Started,
}
329
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// `None` until a summary has been generated (see `set_summary`'s guard).
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: CompletionMode,
    /// Conversation messages, in insertion order with increasing ids.
    messages: Vec<Message>,
    next_message_id: MessageId,
    /// Refreshed by `advance_prompt_id` on each user-initiated prompt.
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints recorded per message, used for restore.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// In-flight completions; non-empty while generating.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken at the start of the current turn; finalized (kept or
    /// dropped) once generation completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// When the last streamed chunk arrived; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    /// Optional observer of raw requests and completion events, installed via
    /// `set_request_callback`.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Agentic turns left; decremented by `send_to_model`.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}

/// Record of the request that last exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
377
378impl Thread {
    /// Creates an empty thread with a freshly generated [`ThreadId`].
    ///
    /// The configured model is seeded from the global
    /// [`LanguageModelRegistry`]'s default, and an initial project snapshot is
    /// captured asynchronously for later serialization.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the snapshot in the background so thread creation
                // doesn't block on worktree/git inspection.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
431
    /// Rebuilds a thread from its serialized form.
    ///
    /// Creases are restored without live context handles (`context: None`),
    /// and the previously selected model is re-resolved against the registry,
    /// falling back to the global default when it is no longer available.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue id numbering after the highest deserialized message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // live context handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
539
    /// Installs a callback invoked with each completion request and the raw
    /// events it produced.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
            + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the last-updated timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt "turn" by generating a fresh [`PromptId`].
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, falling back to (and caching) the global
    /// default when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Placeholder summary used until a real one has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }

    /// Replaces the summary, substituting the default for an empty string and
    /// emitting [`ThreadEvent::SummaryChanged`] on an actual change.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let Some(current_summary) = &self.summary else {
            // Don't allow setting summary until generated
            return;
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = Self::DEFAULT_SUMMARY;
        }

        if current_summary != &new_summary {
            self.summary = Some(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
623
624 pub fn message(&self, id: MessageId) -> Option<&Message> {
625 let index = self
626 .messages
627 .binary_search_by(|message| message.id.cmp(&id))
628 .ok()?;
629
630 self.messages.get(index)
631 }
632
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// True while a completion is pending or any tool use is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    ///
    /// Returns `None` when no chunk has been received yet.
    /// NOTE(review): a previous doc claimed this returns `None` whenever
    /// `is_generating()` is false, but the code only consults
    /// `last_received_chunk_at` — confirm callers reset it between generations.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // More than 250ms since the last chunk counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived (see `is_generation_stale`).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
659
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a still-pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for explicit user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
685
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Record the in-flight restore so the UI can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so it can be surfaced next to the message
                    // the user tried to restore.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
721
    /// Resolves the checkpoint taken at the start of the last turn: keep it
    /// only if the repository actually changed while generating.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still streaming or running tools; finalize once generation ends.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    // The project changed during the turn, so keep the
                    // pre-turn checkpoint available for restore.
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If the final checkpoint can't be taken, conservatively keep the
            // pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
771
772 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
773 let Some(message_ix) = self
774 .messages
775 .iter()
776 .rposition(|message| message.id == message_id)
777 else {
778 return;
779 };
780 for deleted_message in self.messages.drain(message_ix..) {
781 self.checkpoints_by_message.remove(&deleted_message.id);
782 }
783 cx.notify();
784 }
785
786 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
787 self.messages
788 .iter()
789 .find(|message| message.id == id)
790 .into_iter()
791 .flat_map(|message| message.loaded_context.contexts.iter())
792 }
793
794 pub fn is_turn_end(&self, ix: usize) -> bool {
795 if self.messages.is_empty() {
796 return false;
797 }
798
799 if !self.is_generating() && ix == self.messages.len() - 1 {
800 return true;
801 }
802
803 let Some(message) = self.messages.get(ix) else {
804 return false;
805 };
806
807 if message.role != Role::Assistant {
808 return false;
809 }
810
811 self.messages
812 .get(ix + 1)
813 .and_then(|message| {
814 self.message(message.id)
815 .map(|next_message| next_message.role == Role::User)
816 })
817 .unwrap_or(false)
818 }
819
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results attached to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Raw output content of a completed tool use, if present.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// UI card associated with a completed tool use, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
856
857 /// Return tools that are both enabled and supported by the model
858 pub fn available_tools(
859 &self,
860 cx: &App,
861 model: Arc<dyn LanguageModel>,
862 ) -> Vec<LanguageModelRequestTool> {
863 if model.supports_tools() {
864 self.tools()
865 .read(cx)
866 .enabled_tools(cx)
867 .into_iter()
868 .filter_map(|tool| {
869 // Skip tools that cannot be supported
870 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
871 Some(LanguageModelRequestTool {
872 name: tool.name(),
873 description: tool.description(),
874 input_schema,
875 })
876 })
877 .collect()
878 } else {
879 Vec::default()
880 }
881 }
882
    /// Appends a user message, tracking referenced buffers in the action log
    /// and remembering `git_checkpoint` as the turn's pending checkpoint.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }

    /// Appends a message with the next sequential id, bumps the updated-at
    /// timestamp, and emits [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
953
954 pub fn edit_message(
955 &mut self,
956 id: MessageId,
957 new_role: Role,
958 new_segments: Vec<MessageSegment>,
959 loaded_context: Option<LoadedContext>,
960 cx: &mut Context<Self>,
961 ) -> bool {
962 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
963 return false;
964 };
965 message.role = new_role;
966 message.segments = new_segments;
967 if let Some(context) = loaded_context {
968 message.loaded_context = context;
969 }
970 self.touch_updated_at();
971 cx.emit(ThreadEvent::MessageEdited(id));
972 true
973 }
974
975 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
976 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
977 return false;
978 };
979 self.messages.remove(index);
980 self.touch_updated_at();
981 cx.emit(ThreadEvent::MessageDeleted(id));
982 true
983 }
984
985 /// Returns the representation of this [`Thread`] in a textual form.
986 ///
987 /// This is the representation we use when attaching a thread as context to another thread.
988 pub fn text(&self) -> String {
989 let mut text = String::new();
990
991 for message in &self.messages {
992 text.push_str(match message.role {
993 language_model::Role::User => "User:",
994 language_model::Role::Assistant => "Assistant:",
995 language_model::Role::System => "System:",
996 });
997 text.push('\n');
998
999 for segment in &message.segments {
1000 match segment {
1001 MessageSegment::Text(content) => text.push_str(content),
1002 MessageSegment::Thinking { text: content, .. } => {
1003 text.push_str(&format!("<think>{}</think>", content))
1004 }
1005 MessageSegment::RedactedThinking(_) => {}
1006 }
1007 }
1008 text.push('\n');
1009 }
1010
1011 text
1012 }
1013
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the background snapshot before reading thread state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted; live
                        // handles are rebuilt as `None` on deserialize.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
            })
        })
    }
1094
    /// Number of agentic turns this thread may still take before it stops
    /// sending requests (see `send_to_model`).
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1102
1103 pub fn send_to_model(
1104 &mut self,
1105 model: Arc<dyn LanguageModel>,
1106 window: Option<AnyWindowHandle>,
1107 cx: &mut Context<Self>,
1108 ) {
1109 if self.remaining_turns == 0 {
1110 return;
1111 }
1112
1113 self.remaining_turns -= 1;
1114
1115 let request = self.to_completion_request(model.clone(), cx);
1116
1117 self.stream_completion(request, model, window, cx);
1118 }
1119
1120 pub fn used_tools_since_last_user_message(&self) -> bool {
1121 for message in self.messages.iter().rev() {
1122 if self.tool_use.message_has_tool_results(message.id) {
1123 return true;
1124 } else if message.role == Role::User {
1125 return false;
1126 }
1127 }
1128
1129 false
1130 }
1131
    /// Builds the `LanguageModelRequest` for a normal agentic turn: system
    /// prompt, full conversation history (including tool uses and their
    /// results), a stale-file notice, and the tools available to `model`.
    ///
    /// If the project context isn't loaded yet, or the system prompt fails to
    /// render, an error is surfaced via `ThreadEvent::ShowError` and the
    /// request is built without a system message.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // The system prompt lists the available tools, so gather them first.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            // The project context is loaded asynchronously and may not be
            // ready when the first request is built.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Replay the conversation: attached context, message segments, then
        // any tool uses, each followed by its tool-results message.
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // Mark the final message as a cache anchor for prompt caching.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Max mode is only requested when the model actually supports it.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode)
        } else {
            Some(CompletionMode::Normal)
        };

        request
    }
1249
1250 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1251 let mut request = LanguageModelRequest {
1252 thread_id: None,
1253 prompt_id: None,
1254 mode: None,
1255 messages: vec![],
1256 tools: Vec::new(),
1257 stop: Vec::new(),
1258 temperature: None,
1259 };
1260
1261 for message in &self.messages {
1262 let mut request_message = LanguageModelRequestMessage {
1263 role: message.role,
1264 content: Vec::new(),
1265 cache: false,
1266 };
1267
1268 for segment in &message.segments {
1269 match segment {
1270 MessageSegment::Text(text) => request_message
1271 .content
1272 .push(MessageContent::Text(text.clone())),
1273 MessageSegment::Thinking { .. } => {}
1274 MessageSegment::RedactedThinking(_) => {}
1275 }
1276 }
1277
1278 if request_message.content.is_empty() {
1279 continue;
1280 }
1281
1282 request.messages.push(request_message);
1283 }
1284
1285 request.messages.push(LanguageModelRequestMessage {
1286 role: Role::User,
1287 content: vec![MessageContent::Text(added_user_message)],
1288 cache: false,
1289 });
1290
1291 request
1292 }
1293
1294 fn attached_tracked_files_state(
1295 &self,
1296 messages: &mut Vec<LanguageModelRequestMessage>,
1297 cx: &App,
1298 ) {
1299 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1300
1301 let mut stale_message = String::new();
1302
1303 let action_log = self.action_log.read(cx);
1304
1305 for stale_file in action_log.stale_buffers(cx) {
1306 let Some(file) = stale_file.read(cx).file() else {
1307 continue;
1308 };
1309
1310 if stale_message.is_empty() {
1311 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1312 }
1313
1314 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1315 }
1316
1317 let mut content = Vec::with_capacity(2);
1318
1319 if !stale_message.is_empty() {
1320 content.push(stale_message.into());
1321 }
1322
1323 if !content.is_empty() {
1324 let context_message = LanguageModelRequestMessage {
1325 role: Role::User,
1326 content,
1327 cache: false,
1328 };
1329
1330 messages.push(context_message);
1331 }
1332 }
1333
    /// Streams a completion for `request` from `model`, applying events to
    /// the thread as they arrive: text/thinking chunks, tool uses, token
    /// usage, and queue-status updates. When the stream ends it handles the
    /// stop reason (running pending tools on `StopReason::ToolUse`),
    /// surfaces errors to the UI, and reports completion telemetry. The
    /// spawned task is tracked in `pending_completions` until it finishes.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record the request/response pair when a debug callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            // Consumes the event stream, mutating the thread as events
            // arrive; resolves to the final stop reason.
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                // Malformed tool input is recoverable: record it as a
                                // failed tool use and keep consuming the stream.
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates carry running totals for this request;
                                // add only the delta to the thread-wide total.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only emit a streaming event while the tool's input
                                // JSON is still arriving incrementally.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message
                                        } => {
                                            return Err(anyhow!("completion request failed. code: {code}, message: {message}"));
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            cx.emit(ThreadEvent::UsageUpdated(RequestUsage { limit, amount: amount as i32 }));
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield to the executor so other tasks can run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            // The stream has ended (successfully or not): finalize the
            // checkpoint, act on the stop reason or error, and report
            // telemetry for the whole request.
            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to specific UI events; fall back
                            // to a generic message built from the full error chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        // Report only the tokens consumed by this request.
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1678
    /// Kicks off background generation of a short (3-7 word) thread title
    /// using the configured thread-summary model. Does nothing when no
    /// summary model is configured or its provider isn't authenticated.
    /// On success, stores the title in `self.summary` and emits
    /// `ThreadEvent::SummaryGenerated`.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            // Forward usage updates to the UI; they contribute
                            // no title text.
                            this.update(cx, |_, cx| {
                                cx.emit(ThreadEvent::UsageUpdated(RequestUsage {
                                    limit,
                                    amount: amount as i32,
                                }));
                            })?;
                            continue;
                        }
                        // All other event kinds are irrelevant to the title.
                        _ => continue,
                    };

                    // Only the first line of each chunk is kept for the title.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // An empty result leaves any existing summary untouched.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1741
    /// Starts generating a detailed Markdown summary of the conversation
    /// unless one is already generating or generated for the current last
    /// message. Progress is published through `detailed_summary_tx`, and the
    /// thread is saved afterwards so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Request failed: revert to the not-generated state so a
                // later call can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate all chunks; individual chunk errors are logged and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1831
    /// Waits until detailed-summary generation settles, then returns the
    /// generated summary — or the thread's plain text when no summary was
    /// generated. Returns `None` when the state channel closes or the
    /// thread entity has been dropped.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating: wait for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1849
1850 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1851 self.detailed_summary_rx
1852 .borrow()
1853 .text()
1854 .unwrap_or_else(|| self.text().into())
1855 }
1856
    /// Whether a detailed-summary generation task is currently in flight.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1863
    /// Runs every idle pending tool use. Tools whose `needs_confirmation`
    /// returns true — and when the user hasn't enabled "always allow tool
    /// actions" — are queued for confirmation instead of being run
    /// immediately. Tool uses naming a tool that isn't registered are
    /// silently skipped. Returns the pending tool uses that were considered.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Tools receive the full request message history as context.
        let request = self.to_completion_request(model, cx);
        let messages = Arc::new(request.messages);
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    // Park the tool use until the user approves it.
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        window,
                        cx,
                    );
                }
            }
        }

        pending_tool_uses
    }
1910
1911 pub fn receive_invalid_tool_json(
1912 &mut self,
1913 tool_use_id: LanguageModelToolUseId,
1914 tool_name: Arc<str>,
1915 invalid_json: Arc<str>,
1916 error: String,
1917 window: Option<AnyWindowHandle>,
1918 cx: &mut Context<Thread>,
1919 ) {
1920 log::error!("The model returned invalid input JSON: {invalid_json}");
1921
1922 let pending_tool_use = self.tool_use.insert_tool_output(
1923 tool_use_id.clone(),
1924 tool_name,
1925 Err(anyhow!("Error parsing input JSON: {error}")),
1926 self.configured_model.as_ref(),
1927 );
1928 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1929 pending_tool_use.ui_text.clone()
1930 } else {
1931 log::error!(
1932 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1933 );
1934 format!("Unknown tool {}", tool_use_id).into()
1935 };
1936
1937 cx.emit(ThreadEvent::InvalidToolInput {
1938 tool_use_id: tool_use_id.clone(),
1939 ui_text,
1940 invalid_input_json: invalid_json,
1941 });
1942
1943 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1944 }
1945
1946 pub fn run_tool(
1947 &mut self,
1948 tool_use_id: LanguageModelToolUseId,
1949 ui_text: impl Into<SharedString>,
1950 input: serde_json::Value,
1951 messages: &[LanguageModelRequestMessage],
1952 tool: Arc<dyn Tool>,
1953 window: Option<AnyWindowHandle>,
1954 cx: &mut Context<Thread>,
1955 ) {
1956 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1957 self.tool_use
1958 .run_pending_tool(tool_use_id, ui_text.into(), task);
1959 }
1960
    /// Spawns the asynchronous execution of a single tool use and returns
    /// the task that resolves once the tool's output has been recorded.
    /// Disabled tools short-circuit to an error result without running.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // Record the output; a dropped thread makes this a no-op.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2009
2010 fn tool_finished(
2011 &mut self,
2012 tool_use_id: LanguageModelToolUseId,
2013 pending_tool_use: Option<PendingToolUse>,
2014 canceled: bool,
2015 window: Option<AnyWindowHandle>,
2016 cx: &mut Context<Self>,
2017 ) {
2018 if self.all_tools_finished() {
2019 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2020 if !canceled {
2021 self.send_to_model(model.clone(), window, cx);
2022 }
2023 self.auto_capture_telemetry(cx);
2024 }
2025 }
2026
2027 cx.emit(ThreadEvent::ToolFinished {
2028 tool_use_id,
2029 pending_tool_use,
2030 });
2031 }
2032
2033 /// Cancels the last pending completion, if there are any pending.
2034 ///
2035 /// Returns whether a completion was canceled.
2036 pub fn cancel_last_completion(
2037 &mut self,
2038 window: Option<AnyWindowHandle>,
2039 cx: &mut Context<Self>,
2040 ) -> bool {
2041 let mut canceled = self.pending_completions.pop().is_some();
2042
2043 for pending_tool_use in self.tool_use.cancel_pending() {
2044 canceled = true;
2045 self.tool_finished(
2046 pending_tool_use.id.clone(),
2047 Some(pending_tool_use),
2048 true,
2049 window,
2050 cx,
2051 );
2052 }
2053
2054 self.finalize_pending_checkpoint(cx);
2055
2056 if canceled {
2057 cx.emit(ThreadEvent::CompletionCanceled);
2058 }
2059
2060 canceled
2061 }
2062
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It mutates no thread
    /// state itself; it only emits the event.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2070
    /// Returns the thread-level feedback rating, if the user has given one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2074
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2078
    /// Records the user's rating for a single message and reports it —
    /// together with a serialized copy of the thread and a project
    /// snapshot — via telemetry. No-op when the same rating is already
    /// recorded for that message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Repeating the same rating changes nothing and sends no event.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure shouldn't block the rating; fall back to null.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2136
    /// Records thread-level feedback. When the thread contains an assistant
    /// message, the rating is attached to the most recent one via
    /// `report_message_feedback`; otherwise a thread-wide rating event is
    /// reported directly.
    // NOTE(review): the else branch duplicates most of
    // `report_message_feedback`'s telemetry path — consider unifying.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure shouldn't block the rating; fall back to null.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2182
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are gathered concurrently; the task resolves once
    /// all of them (and the dirty-buffer scan) have completed.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect the paths of all dirty (unsaved) file-backed buffers.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2220
    /// Captures a `WorktreeSnapshot` (path plus git state) for a single
    /// worktree. Falls back to an empty path and/or no git state when the
    /// worktree or its repository can no longer be read.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository containing this worktree (if any) and
            // schedule a job on it to read remote URL, HEAD, and diff.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend to query;
                            // otherwise report just the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option/Result layers; any failure along the
            // way yields no git state rather than an error.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2298
2299 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2300 let mut markdown = Vec::new();
2301
2302 if let Some(summary) = self.summary() {
2303 writeln!(markdown, "# {summary}\n")?;
2304 };
2305
2306 for message in self.messages() {
2307 writeln!(
2308 markdown,
2309 "## {role}\n",
2310 role = match message.role {
2311 Role::User => "User",
2312 Role::Assistant => "Assistant",
2313 Role::System => "System",
2314 }
2315 )?;
2316
2317 if !message.loaded_context.text.is_empty() {
2318 writeln!(markdown, "{}", message.loaded_context.text)?;
2319 }
2320
2321 if !message.loaded_context.images.is_empty() {
2322 writeln!(
2323 markdown,
2324 "\n{} images attached as context.\n",
2325 message.loaded_context.images.len()
2326 )?;
2327 }
2328
2329 for segment in &message.segments {
2330 match segment {
2331 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2332 MessageSegment::Thinking { text, .. } => {
2333 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2334 }
2335 MessageSegment::RedactedThinking(_) => {}
2336 }
2337 }
2338
2339 for tool_use in self.tool_uses_for_message(message.id, cx) {
2340 writeln!(
2341 markdown,
2342 "**Use Tool: {} ({})**",
2343 tool_use.name, tool_use.id
2344 )?;
2345 writeln!(markdown, "```json")?;
2346 writeln!(
2347 markdown,
2348 "{}",
2349 serde_json::to_string_pretty(&tool_use.input)?
2350 )?;
2351 writeln!(markdown, "```")?;
2352 }
2353
2354 for tool_result in self.tool_results_for_message(message.id) {
2355 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2356 if tool_result.is_error {
2357 write!(markdown, " (Error)")?;
2358 }
2359
2360 writeln!(markdown, "**\n")?;
2361 writeln!(markdown, "{}", tool_result.content)?;
2362 }
2363 }
2364
2365 Ok(String::from_utf8_lossy(&markdown).to_string())
2366 }
2367
2368 pub fn keep_edits_in_range(
2369 &mut self,
2370 buffer: Entity<language::Buffer>,
2371 buffer_range: Range<language::Anchor>,
2372 cx: &mut Context<Self>,
2373 ) {
2374 self.action_log.update(cx, |action_log, cx| {
2375 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2376 });
2377 }
2378
2379 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2380 self.action_log
2381 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2382 }
2383
2384 pub fn reject_edits_in_ranges(
2385 &mut self,
2386 buffer: Entity<language::Buffer>,
2387 buffer_ranges: Vec<Range<language::Anchor>>,
2388 cx: &mut Context<Self>,
2389 ) -> Task<Result<()>> {
2390 self.action_log.update(cx, |action_log, cx| {
2391 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2392 })
2393 }
2394
    /// The log tracking edits the agent has made in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2398
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2402
2403 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2404 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2405 return;
2406 }
2407
2408 let now = Instant::now();
2409 if let Some(last) = self.last_auto_capture_at {
2410 if now.duration_since(last).as_secs() < 10 {
2411 return;
2412 }
2413 }
2414
2415 self.last_auto_capture_at = Some(now);
2416
2417 let thread_id = self.id().clone();
2418 let github_login = self
2419 .project
2420 .read(cx)
2421 .user_store()
2422 .read(cx)
2423 .current_user()
2424 .map(|user| user.github_login.clone());
2425 let client = self.project.read(cx).client().clone();
2426 let serialize_task = self.serialize(cx);
2427
2428 cx.background_executor()
2429 .spawn(async move {
2430 if let Ok(serialized_thread) = serialize_task.await {
2431 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2432 telemetry::event!(
2433 "Agent Thread Auto-Captured",
2434 thread_id = thread_id.to_string(),
2435 thread_data = thread_data,
2436 auto_capture_reason = "tracked_user",
2437 github_login = github_login
2438 );
2439
2440 client.telemetry().flush_events().await;
2441 }
2442 }
2443 })
2444 .detach();
2445 }
2446
    /// Total token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2450
2451 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2452 let Some(model) = self.configured_model.as_ref() else {
2453 return TotalTokenUsage::default();
2454 };
2455
2456 let max = model.model.max_token_count();
2457
2458 let index = self
2459 .messages
2460 .iter()
2461 .position(|msg| msg.id == message_id)
2462 .unwrap_or(0);
2463
2464 if index == 0 {
2465 return TotalTokenUsage { total: 0, max };
2466 }
2467
2468 let token_usage = &self
2469 .request_token_usage
2470 .get(index - 1)
2471 .cloned()
2472 .unwrap_or_default();
2473
2474 TotalTokenUsage {
2475 total: token_usage.total_tokens() as usize,
2476 max,
2477 }
2478 }
2479
2480 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2481 let model = self.configured_model.as_ref()?;
2482
2483 let max = model.model.max_token_count();
2484
2485 if let Some(exceeded_error) = &self.exceeded_window_error {
2486 if model.model.id() == exceeded_error.model_id {
2487 return Some(TotalTokenUsage {
2488 total: exceeded_error.token_count,
2489 max,
2490 });
2491 }
2492 }
2493
2494 let total = self
2495 .token_usage_at_last_message()
2496 .unwrap_or_default()
2497 .total_tokens() as usize;
2498
2499 Some(TotalTokenUsage { total, max })
2500 }
2501
2502 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2503 self.request_token_usage
2504 .get(self.messages.len().saturating_sub(1))
2505 .or_else(|| self.request_token_usage.last())
2506 .cloned()
2507 }
2508
2509 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2510 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2511 self.request_token_usage
2512 .resize(self.messages.len(), placeholder);
2513
2514 if let Some(last) = self.request_token_usage.last_mut() {
2515 *last = token_usage;
2516 }
2517 }
2518
2519 pub fn deny_tool_use(
2520 &mut self,
2521 tool_use_id: LanguageModelToolUseId,
2522 tool_name: Arc<str>,
2523 window: Option<AnyWindowHandle>,
2524 cx: &mut Context<Self>,
2525 ) {
2526 let err = Err(anyhow::anyhow!(
2527 "Permission to run tool action denied by user"
2528 ));
2529
2530 self.tool_use.insert_tool_output(
2531 tool_use_id.clone(),
2532 tool_name,
2533 err,
2534 self.configured_model.as_ref(),
2535 );
2536 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2537 }
2538}
2539
/// User-facing errors surfaced when a thread completion fails.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before further requests can be made.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model request limit has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a display header and message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2554
/// Events emitted by a thread while streaming a completion and running tools.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Show the given error to the user.
    ShowError(ThreadError),
    /// Request usage information was updated.
    UsageUpdated(RequestUsage),
    /// A completion chunk was streamed in.
    StreamedCompletion,
    /// A text chunk was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant text was streamed into the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant thinking text was streamed into the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that was not valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, with either a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint changed.
    CheckpointChanged,
    /// A tool needs user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Editing was canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2594
// Marker impl allowing `Thread` to emit `ThreadEvent`s through gpui.
impl EventEmitter<ThreadEvent> for Thread {}
2596
/// Bookkeeping for an in-flight completion request.
struct PendingCompletion {
    // Identifier for this completion request.
    id: usize,
    queue_state: QueueState,
    // Holds the spawned completion future so it is not dropped early.
    _task: Task<()>,
}
2602
// Tests covering how user messages, attached file context, and stale-buffer
// notifications are assembled into completion requests.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use assistant_tool::ToolRegistry;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use language_model::fake_provider::FakeLanguageModel;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // Attaching a file as context should render the file into the message's
    // loaded context and prepend it to the message text in the request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    // Context already sent in a previous message must not be repeated: each
    // message should carry only the context that is new relative to the
    // thread (or, when a message id is given, relative to messages before it).
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With an explicit message id, "new" is judged relative to the
        // messages *before* that id, so contexts 2-4 all count as new here.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }

    // A message with no attached context should have empty loaded-context text
    // and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    // Editing a buffer after it was attached as context should cause the
    // request to end with a stale-buffer notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    // Registers all the global settings the thread machinery reads in tests.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    // Builds the workspace, thread store, a fresh thread, a context store, and
    // a fake language model that the tests above share.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }

    // Opens `path` as a buffer and registers it as file context in the store.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
}