1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use util::{ResultExt as _, TryFutureExt as _, post_inc};
39use uuid::Uuid;
40use zed_llm_client::CompletionRequestStatus;
41
42use crate::ThreadStore;
43use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
44use crate::thread_store::{
45 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
46 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
47};
48use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
49
/// Globally unique identifier for a [`Thread`].
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    /// Generates a fresh ID from a random UUIDv4.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
72
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Generates a fresh prompt ID from a random UUIDv4.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
90
/// Monotonically increasing, thread-local identifier for a [`Message`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
99
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    // Offsets of the crease; presumably indices into the message text — confirm with callers.
    pub range: Range<usize>,
    // Icon/label used to render the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
108
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content chunks (text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context attached to this message (cleared of live handles on deserialize).
    pub loaded_context: LoadedContext,
    pub creases: Vec<MessageCrease>,
}
118
119impl Message {
120 /// Returns whether the message contains any meaningful text that should be displayed
121 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
122 pub fn should_display_content(&self) -> bool {
123 self.segments.iter().all(|segment| segment.should_display())
124 }
125
126 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
127 if let Some(MessageSegment::Thinking {
128 text: segment,
129 signature: current_signature,
130 }) = self.segments.last_mut()
131 {
132 if let Some(signature) = signature {
133 *current_signature = Some(signature);
134 }
135 segment.push_str(text);
136 } else {
137 self.segments.push(MessageSegment::Thinking {
138 text: text.to_string(),
139 signature,
140 });
141 }
142 }
143
144 pub fn push_text(&mut self, text: &str) {
145 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
146 segment.push_str(text);
147 } else {
148 self.segments.push(MessageSegment::Text(text.to_string()));
149 }
150 }
151
152 pub fn to_string(&self) -> String {
153 let mut result = String::new();
154
155 if !self.loaded_context.text.is_empty() {
156 result.push_str(&self.loaded_context.text);
157 }
158
159 for segment in &self.segments {
160 match segment {
161 MessageSegment::Text(text) => result.push_str(text),
162 MessageSegment::Thinking { text, .. } => {
163 result.push_str("<think>\n");
164 result.push_str(text);
165 result.push_str("\n</think>");
166 }
167 MessageSegment::RedactedThinking(_) => {}
168 }
169 }
170
171 result
172 }
173}
174
/// One chunk of a message's content: plain text, visible chain-of-thought
/// (optionally signed by the provider), or provider-redacted thinking data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has meaningful content worth displaying.
    ///
    /// Fix: this previously returned `text.is_empty()`, the inverse of the
    /// documented intent (see [`Message::should_display_content`]) — empty
    /// segments claimed to be displayable while non-empty ones did not. Text
    /// and thinking segments should display exactly when they contain text;
    /// redacted thinking is never displayed.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
194
/// Point-in-time capture of the project's worktrees and unsaved buffers,
/// taken when a thread is created (see `Thread::new`) and persisted with it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree: its path and git state (if it is a repo).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Git repository state captured as part of a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Diff output, when one was captured.
    pub diff: Option<String>,
}
215
/// Associates a git checkpoint with the message it was taken for, so the
/// project can be rolled back to the state at that point in the conversation.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User's thumbs-up / thumbs-down rating of a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
227
/// Status of the most recent checkpoint restore: still running, or failed
/// with an error message. Cleared (set back to `None`) on success.
pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}
237
238impl LastRestoreCheckpoint {
239 pub fn message_id(&self) -> MessageId {
240 match self {
241 LastRestoreCheckpoint::Pending { message_id } => *message_id,
242 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
243 }
244 }
245}
246
247#[derive(Clone, Debug, Default, Serialize, Deserialize)]
248pub enum DetailedSummaryState {
249 #[default]
250 NotGenerated,
251 Generating {
252 message_id: MessageId,
253 },
254 Generated {
255 text: SharedString,
256 message_id: MessageId,
257 },
258}
259
260impl DetailedSummaryState {
261 fn text(&self) -> Option<SharedString> {
262 if let Self::Generated { text, .. } = self {
263 Some(text.clone())
264 } else {
265 None
266 }
267 }
268}
269
/// Running token total for a thread together with the model's context-window
/// maximum (`max == 0` when no model is selected).
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage against the context window.
    ///
    /// In debug builds the warning threshold (fraction of `max`, default 0.8)
    /// can be overridden via the `ZED_THREAD_WARNING_THRESHOLD` env var;
    /// release builds always use 0.8.
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            // `unwrap_or_else` avoids allocating the default string when the
            // env var is set (clippy `or_fun_call`).
            .unwrap_or_else(|_| "0.8".to_string())
            .parse()
            // Debug-only knob: fail fast with a clear message on a bad value.
            .expect("ZED_THREAD_WARNING_THRESHOLD must parse as an f32");
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy with `tokens` added to the running total; `max` is kept.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close token usage is to the context-window limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
314
/// Where an in-flight completion request sits in the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
321
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// One-line title; `None` until a summary has been generated.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    /// Messages in ascending `MessageId` order (IDs come from `next_message_id`).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the message they were taken for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured at the last user message; committed to
    /// `checkpoints_by_message` by `finalize_pending_checkpoint`.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Timestamp of the most recent streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Remaining turns before `send_to_model` becomes a no-op.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
362
/// Recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
370
371impl Thread {
    /// Creates an empty thread for `project`, kicking off a background
    /// project snapshot and picking up the globally configured default model.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project state at thread creation; shared so it can
            // be awaited from multiple places (e.g. serialization).
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
425
    /// Rebuilds a thread from its serialized form, restoring messages, tool
    /// state, token usage, and the previously selected model (falling back to
    /// the registry's default when the saved model can't be selected).
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest saved message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    // Creases are restored without live context handles.
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
538
    /// Stores a callback that receives each completion request together with
    /// the events it produced so far (as strings for errors).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
546
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps `updated_at` to the current time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Allocates a fresh prompt ID for the next request.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, if one exists yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, falling back to (and caching) the global
    /// default when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the model used for completions and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
590
591 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
592
593 pub fn summary_or_default(&self) -> SharedString {
594 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
595 }
596
597 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
598 let Some(current_summary) = &self.summary else {
599 // Don't allow setting summary until generated
600 return;
601 };
602
603 let mut new_summary = new_summary.into();
604
605 if new_summary.is_empty() {
606 new_summary = Self::DEFAULT_SUMMARY;
607 }
608
609 if current_summary != &new_summary {
610 self.summary = Some(new_summary);
611 cx.emit(ThreadEvent::SummaryChanged);
612 }
613 }
614
615 pub fn completion_mode(&self) -> CompletionMode {
616 self.completion_mode
617 }
618
619 pub fn set_completion_mode(&mut self, mode: CompletionMode) {
620 self.completion_mode = mode;
621 }
622
623 pub fn message(&self, id: MessageId) -> Option<&Message> {
624 let index = self
625 .messages
626 .binary_search_by(|message| message.id.cmp(&id))
627 .ok()?;
628
629 self.messages.get(index)
630 }
631
632 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
633 self.messages.iter()
634 }
635
636 pub fn is_generating(&self) -> bool {
637 !self.pending_completions.is_empty() || !self.all_tools_finished()
638 }
639
640 /// Indicates whether streaming of language model events is stale.
641 /// When `is_generating()` is false, this method returns `None`.
642 pub fn is_generation_stale(&self) -> Option<bool> {
643 const STALE_THRESHOLD: u128 = 250;
644
645 self.last_received_chunk_at
646 .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
647 }
648
649 fn received_chunk(&mut self) {
650 self.last_received_chunk_at = Some(Instant::now());
651 }
652
653 pub fn queue_state(&self) -> Option<QueueState> {
654 self.pending_completions
655 .first()
656 .map(|pending_completion| pending_completion.queue_state)
657 }
658
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The pending tool use with the given ID, if any.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
684
    /// Restores the project to `checkpoint`'s git state, truncating the
    /// thread back to the checkpoint's message on success. Progress and
    /// failure are tracked in `last_restore_checkpoint` and surfaced via
    /// `CheckpointChanged` events.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so the UI reflects it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failed state around so it can be surfaced.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop the checkpointed message and everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
720
    /// Commits the checkpoint taken at the last user message once generation
    /// has fully finished: the checkpoint is kept only if the repository
    /// changed since it was taken (or if taking/comparing checkpoints failed,
    /// erring on the side of keeping it).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still streaming or running tools; finalize later.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // Comparison failure is treated as "changed".
                    .unwrap_or(false);

                if !equal {
                    // The project changed during this turn; keep the
                    // checkpoint so the user can restore it.
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Taking the final checkpoint failed; keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
759
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint restore (pending or failed), if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    /// Removes the message with `message_id` and everything after it,
    /// dropping the removed messages' checkpoints. No-op when the ID is absent.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
784
785 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
786 self.messages
787 .iter()
788 .find(|message| message.id == id)
789 .into_iter()
790 .flat_map(|message| message.loaded_context.contexts.iter())
791 }
792
793 pub fn is_turn_end(&self, ix: usize) -> bool {
794 if self.messages.is_empty() {
795 return false;
796 }
797
798 if !self.is_generating() && ix == self.messages.len() - 1 {
799 return true;
800 }
801
802 let Some(message) = self.messages.get(ix) else {
803 return false;
804 };
805
806 if message.role != Role::Assistant {
807 return false;
808 }
809
810 self.messages
811 .get(ix + 1)
812 .and_then(|message| {
813 self.message(message.id)
814 .map(|next_message| next_message.role == Role::User)
815 })
816 .unwrap_or(false)
817 }
818
    /// Usage reported for the most recent request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses issued by the message with `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results belonging to the assistant message with the given ID.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of the tool use with `id`, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of the tool use with `id`, if available.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card produced for the tool use with `id`, if one exists.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
859
860 /// Return tools that are both enabled and supported by the model
861 pub fn available_tools(
862 &self,
863 cx: &App,
864 model: Arc<dyn LanguageModel>,
865 ) -> Vec<LanguageModelRequestTool> {
866 if model.supports_tools() {
867 self.tools()
868 .read(cx)
869 .enabled_tools(cx)
870 .into_iter()
871 .filter_map(|tool| {
872 // Skip tools that cannot be supported
873 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
874 Some(LanguageModelRequestTool {
875 name: tool.name(),
876 description: tool.description(),
877 input_schema,
878 })
879 })
880 .collect()
881 } else {
882 Vec::default()
883 }
884 }
885
886 pub fn insert_user_message(
887 &mut self,
888 text: impl Into<String>,
889 loaded_context: ContextLoadResult,
890 git_checkpoint: Option<GitStoreCheckpoint>,
891 creases: Vec<MessageCrease>,
892 cx: &mut Context<Self>,
893 ) -> MessageId {
894 if !loaded_context.referenced_buffers.is_empty() {
895 self.action_log.update(cx, |log, cx| {
896 for buffer in loaded_context.referenced_buffers {
897 log.buffer_read(buffer, cx);
898 }
899 });
900 }
901
902 let message_id = self.insert_message(
903 Role::User,
904 vec![MessageSegment::Text(text.into())],
905 loaded_context.loaded_context,
906 creases,
907 cx,
908 );
909
910 if let Some(git_checkpoint) = git_checkpoint {
911 self.pending_checkpoint = Some(ThreadCheckpoint {
912 message_id,
913 git_checkpoint,
914 });
915 }
916
917 self.auto_capture_telemetry(cx);
918
919 message_id
920 }
921
922 pub fn insert_assistant_message(
923 &mut self,
924 segments: Vec<MessageSegment>,
925 cx: &mut Context<Self>,
926 ) -> MessageId {
927 self.insert_message(
928 Role::Assistant,
929 segments,
930 LoadedContext::default(),
931 Vec::new(),
932 cx,
933 )
934 }
935
936 pub fn insert_message(
937 &mut self,
938 role: Role,
939 segments: Vec<MessageSegment>,
940 loaded_context: LoadedContext,
941 creases: Vec<MessageCrease>,
942 cx: &mut Context<Self>,
943 ) -> MessageId {
944 let id = self.next_message_id.post_inc();
945 self.messages.push(Message {
946 id,
947 role,
948 segments,
949 loaded_context,
950 creases,
951 });
952 self.touch_updated_at();
953 cx.emit(ThreadEvent::MessageAdded(id));
954 id
955 }
956
957 pub fn edit_message(
958 &mut self,
959 id: MessageId,
960 new_role: Role,
961 new_segments: Vec<MessageSegment>,
962 loaded_context: Option<LoadedContext>,
963 cx: &mut Context<Self>,
964 ) -> bool {
965 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
966 return false;
967 };
968 message.role = new_role;
969 message.segments = new_segments;
970 if let Some(context) = loaded_context {
971 message.loaded_context = context;
972 }
973 self.touch_updated_at();
974 cx.emit(ThreadEvent::MessageEdited(id));
975 true
976 }
977
978 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
979 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
980 return false;
981 };
982 self.messages.remove(index);
983 self.touch_updated_at();
984 cx.emit(ThreadEvent::MessageDeleted(id));
985 true
986 }
987
988 /// Returns the representation of this [`Thread`] in a textual form.
989 ///
990 /// This is the representation we use when attaching a thread as context to another thread.
991 pub fn text(&self) -> String {
992 let mut text = String::new();
993
994 for message in &self.messages {
995 text.push_str(match message.role {
996 language_model::Role::User => "User:",
997 language_model::Role::Assistant => "Agent:",
998 language_model::Role::System => "System:",
999 });
1000 text.push('\n');
1001
1002 for segment in &message.segments {
1003 match segment {
1004 MessageSegment::Text(content) => text.push_str(content),
1005 MessageSegment::Thinking { text: content, .. } => {
1006 text.push_str(&format!("<think>{}</think>", content))
1007 }
1008 MessageSegment::RedactedThinking(_) => {}
1009 }
1010 }
1011 text.push('\n');
1012 }
1013
1014 text
1015 }
1016
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the background project snapshot before reading state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1098
    /// Number of turns left before `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }

    /// Builds a completion request from the current thread state and streams
    /// it to `model`, consuming one remaining turn. Does nothing when no
    /// turns remain.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), cx);

        self.stream_completion(request, model, window, cx);
    }
1123
1124 pub fn used_tools_since_last_user_message(&self) -> bool {
1125 for message in self.messages.iter().rev() {
1126 if self.tool_use.message_has_tool_results(message.id) {
1127 return true;
1128 } else if message.role == Role::User {
1129 return false;
1130 }
1131 }
1132
1133 false
1134 }
1135
    /// Builds the complete `LanguageModelRequest` for this thread: a system
    /// prompt derived from the project context, every message (with its loaded
    /// context, tool uses, and tool results), a trailing stale-file notice, and
    /// the set of tools available to `model`.
    ///
    /// Emits `ThreadEvent::ShowError` (but still returns a request) when the
    /// system prompt cannot be generated.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // Tool names are passed to the prompt builder so the system prompt can
        // describe what the model may call.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt requires the project context; if it is not ready we
        // surface an error instead of silently sending a promptless request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt is stable across requests, so mark it
                    // cacheable for providers that support prompt caching.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) precedes the message's own
            // segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            // Empty text/thinking segments are dropped; redacted thinking is
            // always forwarded as-is.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results belong in a follow-up message of their own.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            // Models without max-mode support always run in normal mode.
            Some(CompletionMode::Normal.into())
        };

        request
    }
1253
1254 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1255 let mut request = LanguageModelRequest {
1256 thread_id: None,
1257 prompt_id: None,
1258 mode: None,
1259 messages: vec![],
1260 tools: Vec::new(),
1261 stop: Vec::new(),
1262 temperature: None,
1263 };
1264
1265 for message in &self.messages {
1266 let mut request_message = LanguageModelRequestMessage {
1267 role: message.role,
1268 content: Vec::new(),
1269 cache: false,
1270 };
1271
1272 for segment in &message.segments {
1273 match segment {
1274 MessageSegment::Text(text) => request_message
1275 .content
1276 .push(MessageContent::Text(text.clone())),
1277 MessageSegment::Thinking { .. } => {}
1278 MessageSegment::RedactedThinking(_) => {}
1279 }
1280 }
1281
1282 if request_message.content.is_empty() {
1283 continue;
1284 }
1285
1286 request.messages.push(request_message);
1287 }
1288
1289 request.messages.push(LanguageModelRequestMessage {
1290 role: Role::User,
1291 content: vec![MessageContent::Text(added_user_message)],
1292 cache: false,
1293 });
1294
1295 request
1296 }
1297
1298 fn attached_tracked_files_state(
1299 &self,
1300 messages: &mut Vec<LanguageModelRequestMessage>,
1301 cx: &App,
1302 ) {
1303 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1304
1305 let mut stale_message = String::new();
1306
1307 let action_log = self.action_log.read(cx);
1308
1309 for stale_file in action_log.stale_buffers(cx) {
1310 let Some(file) = stale_file.read(cx).file() else {
1311 continue;
1312 };
1313
1314 if stale_message.is_empty() {
1315 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1316 }
1317
1318 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1319 }
1320
1321 let mut content = Vec::with_capacity(2);
1322
1323 if !stale_message.is_empty() {
1324 content.push(stale_message.into());
1325 }
1326
1327 if !content.is_empty() {
1328 let context_message = LanguageModelRequestMessage {
1329 role: Role::User,
1330 content,
1331 cache: false,
1332 };
1333
1334 messages.push(context_message);
1335 }
1336 }
1337
    /// Starts a streaming completion against `model` and registers it as a
    /// pending completion on this thread.
    ///
    /// The spawned task drains the model's event stream, incrementally building
    /// assistant messages and tool uses and emitting `ThreadEvent`s as chunks
    /// arrive. When the stream ends (or errors), it finalizes the pending
    /// checkpoint, kicks off pending tool uses on `StopReason::ToolUse`,
    /// surfaces errors to the UI, and reports completion telemetry.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        // Unique id for this completion; used to find/remove it in
        // `pending_completions` later.
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record the request/response stream when a test/debug callback is
        // installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage so the telemetry at the end can report this
            // request's delta rather than the lifetime total.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // Id of the assistant message this response streams into, once
                // one has been created.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        // Malformed tool-input JSON is handled inline (the tool
                        // gets an error result); any other error aborts the
                        // stream.
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // add only the delta since the previous update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses must hang off an assistant message;
                                // create one on demand if none exists yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Incomplete input means the tool call is still
                                // streaming; remember it so we can emit a
                                // StreamedToolUse event below.
                                // NOTE(review): `(&tool_use.input).clone()` clones the
                                // underlying value; plain `tool_use.input.clone()`
                                // would read more clearly.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                // Status updates only apply to this completion;
                                // look it up by id in case others are pending.
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message
                                        } => {
                                            return Err(anyhow!("completion request failed. code: {code}, message: {message}"));
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so the UI can process the events we just emitted
                    // before we pull the next chunk.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                // Turn ended: clear the agent's location marker.
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to specific UI errors; fall
                            // back to the full error chain as a message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    // Hand the recorded request/response pair to the installed
                    // callback (used by tests/diagnostics).
                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1684
    /// Asynchronously generates a short title for this thread using the
    /// configured thread-summary model and stores it in `self.summary`.
    ///
    /// No-ops when no summary model is configured or its provider is not
    /// authenticated. On success, emits `ThreadEvent::SummaryGenerated`.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            // Inner async block so `?` failures are funneled into `log_err`
            // instead of propagating.
            async move {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            // Track usage even for summary requests.
                            this.update(cx, |thread, _cx| {
                                thread.last_usage = Some(RequestUsage {
                                    limit,
                                    amount: amount as i32,
                                });
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    // Only the first line of output is kept as the title.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1747
    /// Kicks off generation of a detailed Markdown summary of the thread unless
    /// one is already generated (or generating) for the current last message.
    ///
    /// Progress is published through the `detailed_summary_tx`/`rx` watch
    /// channel; on success the thread is saved via `thread_store` so the
    /// summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails to start, reset the state so a later call
            // can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate the whole streamed summary; chunk errors are logged
            // and skipped rather than aborting.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1837
    /// Waits on the detailed-summary channel until generation settles, then
    /// returns the generated summary, or the thread's plain text if no summary
    /// was generated.
    ///
    /// Returns `None` if the thread entity is dropped or the channel closes.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating; keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1855
1856 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1857 self.detailed_summary_rx
1858 .borrow()
1859 .text()
1860 .unwrap_or_else(|| self.text().into())
1861 }
1862
    /// Reports whether a detailed-summary generation task is currently running.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1869
    /// Dispatches every idle pending tool use: runs it immediately, queues it
    /// for user confirmation (when the tool requires it and auto-allow is off),
    /// or records an error result for tool names the model hallucinated.
    ///
    /// Returns the pending tool uses that were dispatched.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Tools receive the full conversation so far as context.
        let request = self.to_completion_request(model, cx);
        let messages = Arc::new(request.messages);
        // Snapshot the idle tool uses up front; running a tool mutates
        // `self.tool_use`.
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    // Park the tool use until the user confirms it.
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool we don't have; feed the error back.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
1923
1924 pub fn handle_hallucinated_tool_use(
1925 &mut self,
1926 tool_use_id: LanguageModelToolUseId,
1927 hallucinated_tool_name: Arc<str>,
1928 window: Option<AnyWindowHandle>,
1929 cx: &mut Context<Thread>,
1930 ) {
1931 let available_tools = self.tools.read(cx).enabled_tools(cx);
1932
1933 let tool_list = available_tools
1934 .iter()
1935 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
1936 .collect::<Vec<_>>()
1937 .join("\n");
1938
1939 let error_message = format!(
1940 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
1941 hallucinated_tool_name, tool_list
1942 );
1943
1944 let pending_tool_use = self.tool_use.insert_tool_output(
1945 tool_use_id.clone(),
1946 hallucinated_tool_name,
1947 Err(anyhow!("Missing tool call: {error_message}")),
1948 self.configured_model.as_ref(),
1949 );
1950
1951 cx.emit(ThreadEvent::MissingToolUse {
1952 tool_use_id: tool_use_id.clone(),
1953 ui_text: error_message.into(),
1954 });
1955
1956 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1957 }
1958
1959 pub fn receive_invalid_tool_json(
1960 &mut self,
1961 tool_use_id: LanguageModelToolUseId,
1962 tool_name: Arc<str>,
1963 invalid_json: Arc<str>,
1964 error: String,
1965 window: Option<AnyWindowHandle>,
1966 cx: &mut Context<Thread>,
1967 ) {
1968 log::error!("The model returned invalid input JSON: {invalid_json}");
1969
1970 let pending_tool_use = self.tool_use.insert_tool_output(
1971 tool_use_id.clone(),
1972 tool_name,
1973 Err(anyhow!("Error parsing input JSON: {error}")),
1974 self.configured_model.as_ref(),
1975 );
1976 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1977 pending_tool_use.ui_text.clone()
1978 } else {
1979 log::error!(
1980 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1981 );
1982 format!("Unknown tool {}", tool_use_id).into()
1983 };
1984
1985 cx.emit(ThreadEvent::InvalidToolInput {
1986 tool_use_id: tool_use_id.clone(),
1987 ui_text,
1988 invalid_input_json: invalid_json,
1989 });
1990
1991 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1992 }
1993
1994 pub fn run_tool(
1995 &mut self,
1996 tool_use_id: LanguageModelToolUseId,
1997 ui_text: impl Into<SharedString>,
1998 input: serde_json::Value,
1999 messages: &[LanguageModelRequestMessage],
2000 tool: Arc<dyn Tool>,
2001 window: Option<AnyWindowHandle>,
2002 cx: &mut Context<Thread>,
2003 ) {
2004 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
2005 self.tool_use
2006 .run_pending_tool(tool_use_id, ui_text.into(), task);
2007 }
2008
    /// Runs `tool` (unless it is disabled, in which case the result is an
    /// immediate error) and returns a task that records its output on this
    /// thread and calls `tool_finished` when done.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // A disabled tool still produces a result so the model sees the refusal.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped meanwhile; ignore that case.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2057
2058 fn tool_finished(
2059 &mut self,
2060 tool_use_id: LanguageModelToolUseId,
2061 pending_tool_use: Option<PendingToolUse>,
2062 canceled: bool,
2063 window: Option<AnyWindowHandle>,
2064 cx: &mut Context<Self>,
2065 ) {
2066 if self.all_tools_finished() {
2067 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2068 if !canceled {
2069 self.send_to_model(model.clone(), window, cx);
2070 }
2071 self.auto_capture_telemetry(cx);
2072 }
2073 }
2074
2075 cx.emit(ThreadEvent::ToolFinished {
2076 tool_use_id,
2077 pending_tool_use,
2078 });
2079 }
2080
2081 /// Cancels the last pending completion, if there are any pending.
2082 ///
2083 /// Returns whether a completion was canceled.
2084 pub fn cancel_last_completion(
2085 &mut self,
2086 window: Option<AnyWindowHandle>,
2087 cx: &mut Context<Self>,
2088 ) -> bool {
2089 let mut canceled = self.pending_completions.pop().is_some();
2090
2091 for pending_tool_use in self.tool_use.cancel_pending() {
2092 canceled = true;
2093 self.tool_finished(
2094 pending_tool_use.id.clone(),
2095 Some(pending_tool_use),
2096 true,
2097 window,
2098 cx,
2099 );
2100 }
2101
2102 self.finalize_pending_checkpoint(cx);
2103
2104 if canceled {
2105 cx.emit(ThreadEvent::CompletionCanceled);
2106 }
2107
2108 canceled
2109 }
2110
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2118
    /// Returns the thread-level feedback rating, if the user has given one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2122
    /// Returns the feedback rating given for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2126
    /// Records the user's feedback for `message_id` and reports it (together
    /// with a serialized thread and project snapshot) to telemetry.
    ///
    /// Returns a ready task if the same feedback was already recorded.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Re-submitting identical feedback is a no-op.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything needed by the background task before moving off
        // the main thread.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2184
    /// Records thread-level feedback, attributing it to the last assistant
    /// message when one exists; otherwise reports a thread-wide telemetry
    /// event without a message id.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            // Delegate to the per-message path so the rating is tied to the
            // assistant's latest response.
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure degrades to a null payload rather than
                // dropping the event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2230
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Snapshots every visible worktree, then collects the paths of all dirty
    /// buffers, and stamps the result with the current UTC time.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Start one snapshot task per visible worktree; they are awaited
        // together below.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            // Best-effort: if the app context is gone, the snapshot simply
            // reports no unsaved buffers.
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2268
2269 fn worktree_snapshot(
2270 worktree: Entity<project::Worktree>,
2271 git_store: Entity<GitStore>,
2272 cx: &App,
2273 ) -> Task<WorktreeSnapshot> {
2274 cx.spawn(async move |cx| {
2275 // Get worktree path and snapshot
2276 let worktree_info = cx.update(|app_cx| {
2277 let worktree = worktree.read(app_cx);
2278 let path = worktree.abs_path().to_string_lossy().to_string();
2279 let snapshot = worktree.snapshot();
2280 (path, snapshot)
2281 });
2282
2283 let Ok((worktree_path, _snapshot)) = worktree_info else {
2284 return WorktreeSnapshot {
2285 worktree_path: String::new(),
2286 git_state: None,
2287 };
2288 };
2289
2290 let git_state = git_store
2291 .update(cx, |git_store, cx| {
2292 git_store
2293 .repositories()
2294 .values()
2295 .find(|repo| {
2296 repo.read(cx)
2297 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2298 .is_some()
2299 })
2300 .cloned()
2301 })
2302 .ok()
2303 .flatten()
2304 .map(|repo| {
2305 repo.update(cx, |repo, _| {
2306 let current_branch =
2307 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2308 repo.send_job(None, |state, _| async move {
2309 let RepositoryState::Local { backend, .. } = state else {
2310 return GitState {
2311 remote_url: None,
2312 head_sha: None,
2313 current_branch,
2314 diff: None,
2315 };
2316 };
2317
2318 let remote_url = backend.remote_url("origin");
2319 let head_sha = backend.head_sha().await;
2320 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2321
2322 GitState {
2323 remote_url,
2324 head_sha,
2325 current_branch,
2326 diff,
2327 }
2328 })
2329 })
2330 });
2331
2332 let git_state = match git_state {
2333 Some(git_state) => match git_state.ok() {
2334 Some(git_state) => git_state.await.ok(),
2335 None => None,
2336 },
2337 None => None,
2338 };
2339
2340 WorktreeSnapshot {
2341 worktree_path,
2342 git_state,
2343 }
2344 })
2345 }
2346
2347 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2348 let mut markdown = Vec::new();
2349
2350 if let Some(summary) = self.summary() {
2351 writeln!(markdown, "# {summary}\n")?;
2352 };
2353
2354 for message in self.messages() {
2355 writeln!(
2356 markdown,
2357 "## {role}\n",
2358 role = match message.role {
2359 Role::User => "User",
2360 Role::Assistant => "Agent",
2361 Role::System => "System",
2362 }
2363 )?;
2364
2365 if !message.loaded_context.text.is_empty() {
2366 writeln!(markdown, "{}", message.loaded_context.text)?;
2367 }
2368
2369 if !message.loaded_context.images.is_empty() {
2370 writeln!(
2371 markdown,
2372 "\n{} images attached as context.\n",
2373 message.loaded_context.images.len()
2374 )?;
2375 }
2376
2377 for segment in &message.segments {
2378 match segment {
2379 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2380 MessageSegment::Thinking { text, .. } => {
2381 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2382 }
2383 MessageSegment::RedactedThinking(_) => {}
2384 }
2385 }
2386
2387 for tool_use in self.tool_uses_for_message(message.id, cx) {
2388 writeln!(
2389 markdown,
2390 "**Use Tool: {} ({})**",
2391 tool_use.name, tool_use.id
2392 )?;
2393 writeln!(markdown, "```json")?;
2394 writeln!(
2395 markdown,
2396 "{}",
2397 serde_json::to_string_pretty(&tool_use.input)?
2398 )?;
2399 writeln!(markdown, "```")?;
2400 }
2401
2402 for tool_result in self.tool_results_for_message(message.id) {
2403 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2404 if tool_result.is_error {
2405 write!(markdown, " (Error)")?;
2406 }
2407
2408 writeln!(markdown, "**\n")?;
2409 writeln!(markdown, "{}", tool_result.content)?;
2410 }
2411 }
2412
2413 Ok(String::from_utf8_lossy(&markdown).to_string())
2414 }
2415
2416 pub fn keep_edits_in_range(
2417 &mut self,
2418 buffer: Entity<language::Buffer>,
2419 buffer_range: Range<language::Anchor>,
2420 cx: &mut Context<Self>,
2421 ) {
2422 self.action_log.update(cx, |action_log, cx| {
2423 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2424 });
2425 }
2426
2427 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2428 self.action_log
2429 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2430 }
2431
2432 pub fn reject_edits_in_ranges(
2433 &mut self,
2434 buffer: Entity<language::Buffer>,
2435 buffer_ranges: Vec<Range<language::Anchor>>,
2436 cx: &mut Context<Self>,
2437 ) -> Task<Result<()>> {
2438 self.action_log.update(cx, |action_log, cx| {
2439 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2440 })
2441 }
2442
    /// The log of edits the agent has made during this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2446
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2450
2451 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2452 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2453 return;
2454 }
2455
2456 let now = Instant::now();
2457 if let Some(last) = self.last_auto_capture_at {
2458 if now.duration_since(last).as_secs() < 10 {
2459 return;
2460 }
2461 }
2462
2463 self.last_auto_capture_at = Some(now);
2464
2465 let thread_id = self.id().clone();
2466 let github_login = self
2467 .project
2468 .read(cx)
2469 .user_store()
2470 .read(cx)
2471 .current_user()
2472 .map(|user| user.github_login.clone());
2473 let client = self.project.read(cx).client().clone();
2474 let serialize_task = self.serialize(cx);
2475
2476 cx.background_executor()
2477 .spawn(async move {
2478 if let Ok(serialized_thread) = serialize_task.await {
2479 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2480 telemetry::event!(
2481 "Agent Thread Auto-Captured",
2482 thread_id = thread_id.to_string(),
2483 thread_data = thread_data,
2484 auto_capture_reason = "tracked_user",
2485 github_login = github_login
2486 );
2487
2488 client.telemetry().flush_events().await;
2489 }
2490 }
2491 })
2492 .detach();
2493 }
2494
    /// Total token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2498
2499 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2500 let Some(model) = self.configured_model.as_ref() else {
2501 return TotalTokenUsage::default();
2502 };
2503
2504 let max = model.model.max_token_count();
2505
2506 let index = self
2507 .messages
2508 .iter()
2509 .position(|msg| msg.id == message_id)
2510 .unwrap_or(0);
2511
2512 if index == 0 {
2513 return TotalTokenUsage { total: 0, max };
2514 }
2515
2516 let token_usage = &self
2517 .request_token_usage
2518 .get(index - 1)
2519 .cloned()
2520 .unwrap_or_default();
2521
2522 TotalTokenUsage {
2523 total: token_usage.total_tokens() as usize,
2524 max,
2525 }
2526 }
2527
2528 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2529 let model = self.configured_model.as_ref()?;
2530
2531 let max = model.model.max_token_count();
2532
2533 if let Some(exceeded_error) = &self.exceeded_window_error {
2534 if model.model.id() == exceeded_error.model_id {
2535 return Some(TotalTokenUsage {
2536 total: exceeded_error.token_count,
2537 max,
2538 });
2539 }
2540 }
2541
2542 let total = self
2543 .token_usage_at_last_message()
2544 .unwrap_or_default()
2545 .total_tokens() as usize;
2546
2547 Some(TotalTokenUsage { total, max })
2548 }
2549
2550 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2551 self.request_token_usage
2552 .get(self.messages.len().saturating_sub(1))
2553 .or_else(|| self.request_token_usage.last())
2554 .cloned()
2555 }
2556
2557 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2558 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2559 self.request_token_usage
2560 .resize(self.messages.len(), placeholder);
2561
2562 if let Some(last) = self.request_token_usage.last_mut() {
2563 *last = token_usage;
2564 }
2565 }
2566
2567 pub fn deny_tool_use(
2568 &mut self,
2569 tool_use_id: LanguageModelToolUseId,
2570 tool_name: Arc<str>,
2571 window: Option<AnyWindowHandle>,
2572 cx: &mut Context<Self>,
2573 ) {
2574 let err = Err(anyhow::anyhow!(
2575 "Permission to run tool action denied by user"
2576 ));
2577
2578 self.tool_use.insert_tool_output(
2579 tool_use_id.clone(),
2580 tool_name,
2581 err,
2582 self.configured_model.as_ref(),
2583 );
2584 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2585 }
2586}
2587
/// User-visible errors that can occur while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before more completions can run.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model request limit has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a header and body to display to the user.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2602
/// Events emitted by a thread as completions stream in, tools run, and
/// messages change.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    /// A completion chunk was received and applied.
    StreamedCompletion,
    /// A text chunk arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant text to append to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" content to append to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool-use block streamed in from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    // NOTE(review): presumably emitted when the model signals tool use but no
    // matching tool-use block was produced — confirm against the emitter.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse; the raw JSON is
    /// kept for display/debugging.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message's content changed.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A git checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Editing should be canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2645
// Allows `Thread` to emit `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2647
/// An in-flight completion request, kept alive by its background task.
struct PendingCompletion {
    // Monotonic id distinguishing this completion from later ones.
    id: usize,
    // NOTE(review): appears to track this request's place in the provider
    // queue — confirm against where `queue_state` is written.
    queue_state: QueueState,
    // Held so the completion keeps running; presumably dropping it cancels
    // the work (gpui `Task` semantics) — confirm.
    _task: Task<()>,
}
2653
2654#[cfg(test)]
2655mod tests {
2656 use super::*;
2657 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2658 use assistant_settings::AssistantSettings;
2659 use assistant_tool::ToolRegistry;
2660 use editor::EditorSettings;
2661 use gpui::TestAppContext;
2662 use language_model::fake_provider::FakeLanguageModel;
2663 use project::{FakeFs, Project};
2664 use prompt_store::PromptBuilder;
2665 use serde_json::json;
2666 use settings::{Settings, SettingsStore};
2667 use std::sync::Arc;
2668 use theme::ThemeSettings;
2669 use util::path;
2670 use workspace::Workspace;
2671
    /// Verifies that file context attached to a user message is rendered into
    /// the message's loaded context and into the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Message 0 is the system prompt; message 1 is the user's.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2747
    /// Verifies that each new user message only carries context that was not
    /// already attached to an earlier message, and that
    /// `new_context_for_thread` respects an "up to this message" cutoff.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages (plus the system prompt)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, contexts from message 3 onward (and new
        // ones) all count as "new" again.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2903
    /// Verifies that messages inserted with no context produce empty loaded
    /// context and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2981
    /// Verifies that when a context buffer is edited after being read, the
    /// next completion request ends with a "files changed since last read"
    /// notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3068
    /// Registers all the settings and globals the thread tests depend on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3085
    // Helper to create a test project rooted at /test containing the given
    // file tree (expressed as a serde_json::Value).
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }
3095
    /// Builds the standard test fixture: a workspace, a thread store, a fresh
    /// thread, an empty context store, and a fake language model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        // A fake model lets tests inspect completion requests without I/O.
        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }
3130
    /// Opens the buffer at `path` in the project and registers it as file
    /// context in the context store, returning the buffer handle.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
3154}