1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use util::{ResultExt as _, TryFutureExt as _, post_inc};
39use uuid::Uuid;
40use zed_llm_client::{CompletionMode, CompletionRequestStatus};
41
42use crate::ThreadStore;
43use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
44use crate::thread_store::{
45 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
46 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
47};
48use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
49
50#[derive(
51 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
52)]
53pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
73/// The ID of the user prompt that initiated a request.
74///
75/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
76#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
77pub struct PromptId(Arc<str>);
78
79impl PromptId {
80 pub fn new() -> Self {
81 Self(Uuid::new_v4().to_string().into())
82 }
83}
84
85impl std::fmt::Display for PromptId {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 write!(f, "{}", self.0)
88 }
89}
90
91#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
92pub struct MessageId(pub(crate) usize);
93
94impl MessageId {
95 fn post_inc(&mut self) -> Self {
96 Self(post_inc(&mut self.0))
97 }
98}
99
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    // Range within the message text covered by the crease.
    pub range: Range<usize>,
    // Icon and label used to render the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
108
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    // Ordered pieces of the message body (text / thinking / redacted thinking).
    pub segments: Vec<MessageSegment>,
    // Context (files, etc.) attached to this message when it was sent.
    pub loaded_context: LoadedContext,
    // Creases to recreate when the message is re-opened in an editor.
    pub creases: Vec<MessageCrease>,
}
118
119impl Message {
120 /// Returns whether the message contains any meaningful text that should be displayed
121 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
122 pub fn should_display_content(&self) -> bool {
123 self.segments.iter().all(|segment| segment.should_display())
124 }
125
126 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
127 if let Some(MessageSegment::Thinking {
128 text: segment,
129 signature: current_signature,
130 }) = self.segments.last_mut()
131 {
132 if let Some(signature) = signature {
133 *current_signature = Some(signature);
134 }
135 segment.push_str(text);
136 } else {
137 self.segments.push(MessageSegment::Thinking {
138 text: text.to_string(),
139 signature,
140 });
141 }
142 }
143
144 pub fn push_text(&mut self, text: &str) {
145 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
146 segment.push_str(text);
147 } else {
148 self.segments.push(MessageSegment::Text(text.to_string()));
149 }
150 }
151
152 pub fn to_string(&self) -> String {
153 let mut result = String::new();
154
155 if !self.loaded_context.text.is_empty() {
156 result.push_str(&self.loaded_context.text);
157 }
158
159 for segment in &self.segments {
160 match segment {
161 MessageSegment::Text(text) => result.push_str(text),
162 MessageSegment::Thinking { text, .. } => {
163 result.push_str("<think>\n");
164 result.push_str(text);
165 result.push_str("\n</think>");
166 }
167 MessageSegment::RedactedThinking(_) => {}
168 }
169 }
170
171 result
172 }
173}
174
/// One piece of a message: plain text, a thinking block (optionally carrying
/// a provider signature), or redacted (opaque) thinking data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment contains meaningful, displayable text.
    ///
    /// Fix: this previously returned `text.is_empty()`, which inverted the
    /// documented meaning — empty segments were reported as displayable and
    /// non-empty ones were not (see `Message::should_display_content`).
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking carries no human-readable text.
            Self::RedactedThinking(_) => false,
        }
    }
}
194
/// Point-in-time capture of the project's worktrees, used for storage/telemetry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Paths of buffers with unsaved edits at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // None when the worktree is not a git repository (or state was unavailable).
    pub git_state: Option<GitState>,
}

/// Captured git repository state; each field is independently optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
215
/// Associates a git checkpoint with the message at which it was taken, so the
/// project can be rolled back to the state before that message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User's thumbs-up / thumbs-down rating of a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
227
228pub enum LastRestoreCheckpoint {
229 Pending {
230 message_id: MessageId,
231 },
232 Error {
233 message_id: MessageId,
234 error: String,
235 },
236}
237
238impl LastRestoreCheckpoint {
239 pub fn message_id(&self) -> MessageId {
240 match self {
241 LastRestoreCheckpoint::Pending { message_id } => *message_id,
242 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
243 }
244 }
245}
246
247#[derive(Clone, Debug, Default, Serialize, Deserialize)]
248pub enum DetailedSummaryState {
249 #[default]
250 NotGenerated,
251 Generating {
252 message_id: MessageId,
253 },
254 Generated {
255 text: SharedString,
256 message_id: MessageId,
257 },
258}
259
260impl DetailedSummaryState {
261 fn text(&self) -> Option<SharedString> {
262 if let Self::Generated { text, .. } = self {
263 Some(text.clone())
264 } else {
265 None
266 }
267 }
268}
269
/// Aggregate token consumption for a thread together with the model's
/// context-window size.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, nearing the limit, or exceeded.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be tuned via the
        // `ZED_THREAD_WARNING_THRESHOLD` env var. Fix: a malformed value used
        // to panic via `.unwrap()`; it now falls back to the default.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Coarse classification of context-window consumption.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
314
315fn default_completion_mode(cx: &App) -> CompletionMode {
316 if cx.is_staff() {
317 CompletionMode::Max
318 } else {
319 CompletionMode::Normal
320 }
321}
322
/// Where a pending completion currently sits relative to the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    // Request issued, not yet queued or started.
    Sending,
    // Waiting behind `position` other requests.
    Queued { position: usize },
    // Streaming has begun.
    Started,
}
329
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // None until a summary has been generated.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: CompletionMode,
    // Kept in ascending `MessageId` order; `Self::message` relies on this for
    // its binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint captured when the latest user message was inserted; finalized
    // or discarded by `finalize_pending_checkpoint`.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recent streamed chunk; used by `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Optional hook invoked with each request and the events received for it.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Turns still allowed before `send_to_model` becomes a no-op.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
370
/// Details recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
378
379impl Thread {
    /// Creates an empty thread backed by `project`, with a fresh ID, no
    /// messages, and the globally-configured default model.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Kick off the snapshot immediately and share the task so
                // multiple consumers (e.g. `serialize`) can await the result.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
433
    /// Rebuilds a thread from its serialized form, restoring messages and
    /// tool-use state, and re-selecting the previously used model (falling
    /// back to the registry default when it is no longer available).
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model recorded in the serialized thread; otherwise the
        // registry's default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    // Context handles are not persisted, so restored creases
                    // carry `context: None` (see `MessageCrease`).
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
542
    /// Installs a hook invoked with each completion request and the stream of
    /// events received for it (useful for inspection/diagnostics).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
550
    /// This thread's unique identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
554
    /// Returns whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
558
    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
562
    /// Stamps the thread as modified right now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
566
    /// Starts a new prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }
570
    /// The generated summary, if one exists yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
574
    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
578
579 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
580 if self.configured_model.is_none() {
581 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
582 }
583 self.configured_model.clone()
584 }
585
    /// The currently configured model, if any (no lazy initialization).
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }
589
    /// Sets (or clears) the configured model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
594
    /// Placeholder title used before a summary has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The summary, or [`Self::DEFAULT_SUMMARY`] when none has been generated.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
600
601 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
602 let Some(current_summary) = &self.summary else {
603 // Don't allow setting summary until generated
604 return;
605 };
606
607 let mut new_summary = new_summary.into();
608
609 if new_summary.is_empty() {
610 new_summary = Self::DEFAULT_SUMMARY;
611 }
612
613 if current_summary != &new_summary {
614 self.summary = Some(new_summary);
615 cx.emit(ThreadEvent::SummaryChanged);
616 }
617 }
618
    /// The current completion mode (normal vs. max).
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }
622
    /// Sets the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
626
627 pub fn message(&self, id: MessageId) -> Option<&Message> {
628 let index = self
629 .messages
630 .binary_search_by(|message| message.id.cmp(&id))
631 .ok()?;
632
633 self.messages.get(index)
634 }
635
    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
639
    /// Returns whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
643
    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a received chunk before the stream counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        // NOTE(review): this returns `None` only when no chunk timestamp is
        // recorded; the doc's "None when not generating" holds only if
        // `last_received_chunk_at` is cleared when generation ends — confirm
        // at the call sites that reset it.
        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }
652
    /// Records that a streamed chunk just arrived (for staleness detection).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }
656
    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
662
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
666
667 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
668 self.tool_use
669 .pending_tool_uses()
670 .into_iter()
671 .find(|tool_use| &tool_use.id == id)
672 }
673
    /// Iterates over pending tool uses that are waiting for user confirmation
    /// before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }
680
    /// Returns whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
684
    /// The git checkpoint associated with a message, if one was recorded.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
688
    /// Restores the project's git state to `checkpoint`, then truncates the
    /// thread back to the checkpoint's message on success. Failures are
    /// recorded in `last_restore_checkpoint` and surfaced via events.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so the UI reflects it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failed state visible; do not truncate.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way, the pending checkpoint no longer applies.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
724
    /// Resolves the checkpoint taken at the last user message once generation
    /// finishes: records it only if the repository state actually changed
    /// since capture (or if the comparison itself failed).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        // Defer while still generating; bail if nothing is pending.
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // Treat a failed comparison as "changed" so the checkpoint is kept.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Could not take a final checkpoint; keep the pending one so the
            // user can still restore.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
763
    /// Records a checkpoint against its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
770
    /// Status of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
774
775 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
776 let Some(message_ix) = self
777 .messages
778 .iter()
779 .rposition(|message| message.id == message_id)
780 else {
781 return;
782 };
783 for deleted_message in self.messages.drain(message_ix..) {
784 self.checkpoints_by_message.remove(&deleted_message.id);
785 }
786 cx.notify();
787 }
788
    /// Iterates over the agent contexts attached to a message (empty iterator
    /// when the message is not found).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
796
797 pub fn is_turn_end(&self, ix: usize) -> bool {
798 if self.messages.is_empty() {
799 return false;
800 }
801
802 if !self.is_generating() && ix == self.messages.len() - 1 {
803 return true;
804 }
805
806 let Some(message) = self.messages.get(ix) else {
807 return false;
808 };
809
810 if message.role != Role::Assistant {
811 return false;
812 }
813
814 self.messages
815 .get(ix + 1)
816 .and_then(|message| {
817 self.message(message.id)
818 .map(|next_message| next_message.role == Role::User)
819 })
820 .unwrap_or(false)
821 }
822
    /// Usage reported by the most recent request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }
826
    /// Returns whether the provider signaled that the tool-use limit was hit.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
830
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
840
    /// Tool uses issued by the given (assistant) message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
844
    /// Tool results produced for the given assistant message's tool uses.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
851
    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
855
856 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
857 Some(&self.tool_use.tool_result(id)?.content)
858 }
859
    /// The UI card associated with a tool's result, if one was registered.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
863
864 /// Return tools that are both enabled and supported by the model
865 pub fn available_tools(
866 &self,
867 cx: &App,
868 model: Arc<dyn LanguageModel>,
869 ) -> Vec<LanguageModelRequestTool> {
870 if model.supports_tools() {
871 self.tools()
872 .read(cx)
873 .enabled_tools(cx)
874 .into_iter()
875 .filter_map(|tool| {
876 // Skip tools that cannot be supported
877 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
878 Some(LanguageModelRequestTool {
879 name: tool.name(),
880 description: tool.description(),
881 input_schema,
882 })
883 })
884 .collect()
885 } else {
886 Vec::default()
887 }
888 }
889
    /// Appends a user message with its loaded context and creases, optionally
    /// recording a git checkpoint to finalize once generation completes.
    /// Returns the new message's ID.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Track any buffers referenced by the attached context so edits to
        // them are logged.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        // Stash the checkpoint; `finalize_pending_checkpoint` decides later
        // whether to keep it.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
925
    /// Appends an assistant message with the given segments (no context or
    /// creases) and returns its ID.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
939
    /// Appends a message with the next sequential ID, stamps the thread as
    /// updated, and emits [`ThreadEvent::MessageAdded`]. Appending preserves
    /// the ascending-ID ordering that `Self::message` relies on.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
960
    /// Replaces a message's role and segments (and optionally its context).
    /// Returns `false` when no message with `id` exists.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
981
    /// Removes a message by ID. Returns `false` when no such message exists.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }
991
992 /// Returns the representation of this [`Thread`] in a textual form.
993 ///
994 /// This is the representation we use when attaching a thread as context to another thread.
995 pub fn text(&self) -> String {
996 let mut text = String::new();
997
998 for message in &self.messages {
999 text.push_str(match message.role {
1000 language_model::Role::User => "User:",
1001 language_model::Role::Assistant => "Assistant:",
1002 language_model::Role::System => "System:",
1003 });
1004 text.push('\n');
1005
1006 for segment in &message.segments {
1007 match segment {
1008 MessageSegment::Text(content) => text.push_str(content),
1009 MessageSegment::Thinking { text: content, .. } => {
1010 text.push_str(&format!("<think>{}</think>", content))
1011 }
1012 MessageSegment::RedactedThinking(_) => {}
1013 }
1014 }
1015 text.push('\n');
1016 }
1017
1018 text
1019 }
1020
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Clone the shared snapshot task so the spawned future can await it
        // without borrowing `self`.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results live in `ToolUseState`, keyed by
                        // message; persist them alongside each message.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
            })
        })
    }
1101
    /// Number of model turns still allowed before `send_to_model` no-ops.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1105
    /// Sets the turn budget consumed by `send_to_model`.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1109
    /// Sends the current conversation to `model`, consuming one remaining
    /// turn. Does nothing when the turn budget is exhausted.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), cx);

        self.stream_completion(request, model, window, cx);
    }
1126
1127 pub fn used_tools_since_last_user_message(&self) -> bool {
1128 for message in self.messages.iter().rev() {
1129 if self.tool_use.message_has_tool_results(message.id) {
1130 return true;
1131 } else if message.role == Role::User {
1132 return false;
1133 }
1134 }
1135
1136 false
1137 }
1138
    /// Assembles the complete [`LanguageModelRequest`] for this thread:
    /// system prompt, conversation history (loaded context, message segments,
    /// tool uses and tool results), a stale-file notice, the available tools,
    /// and the completion mode.
    ///
    /// Emits a [`ThreadEvent::ShowError`] (and continues without a system
    /// message) when the system prompt cannot be generated.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt is derived from the shared project context. If the
        // context isn't ready, or prompt generation fails, surface the error
        // and continue without a system message rather than aborting.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Replay the conversation: each message contributes its loaded
        // context, its segments, any tool uses, and (as a follow-up message)
        // any tool results.
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // Mark the final message as a cache breakpoint.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Only pass the thread's completion mode through when the model
        // supports max mode; otherwise force Normal.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode)
        } else {
            Some(CompletionMode::Normal)
        };

        request
    }
1256
1257 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1258 let mut request = LanguageModelRequest {
1259 thread_id: None,
1260 prompt_id: None,
1261 mode: None,
1262 messages: vec![],
1263 tools: Vec::new(),
1264 stop: Vec::new(),
1265 temperature: None,
1266 };
1267
1268 for message in &self.messages {
1269 let mut request_message = LanguageModelRequestMessage {
1270 role: message.role,
1271 content: Vec::new(),
1272 cache: false,
1273 };
1274
1275 for segment in &message.segments {
1276 match segment {
1277 MessageSegment::Text(text) => request_message
1278 .content
1279 .push(MessageContent::Text(text.clone())),
1280 MessageSegment::Thinking { .. } => {}
1281 MessageSegment::RedactedThinking(_) => {}
1282 }
1283 }
1284
1285 if request_message.content.is_empty() {
1286 continue;
1287 }
1288
1289 request.messages.push(request_message);
1290 }
1291
1292 request.messages.push(LanguageModelRequestMessage {
1293 role: Role::User,
1294 content: vec![MessageContent::Text(added_user_message)],
1295 cache: false,
1296 });
1297
1298 request
1299 }
1300
1301 fn attached_tracked_files_state(
1302 &self,
1303 messages: &mut Vec<LanguageModelRequestMessage>,
1304 cx: &App,
1305 ) {
1306 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1307
1308 let mut stale_message = String::new();
1309
1310 let action_log = self.action_log.read(cx);
1311
1312 for stale_file in action_log.stale_buffers(cx) {
1313 let Some(file) = stale_file.read(cx).file() else {
1314 continue;
1315 };
1316
1317 if stale_message.is_empty() {
1318 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1319 }
1320
1321 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1322 }
1323
1324 let mut content = Vec::with_capacity(2);
1325
1326 if !stale_message.is_empty() {
1327 content.push(stale_message.into());
1328 }
1329
1330 if !content.is_empty() {
1331 let context_message = LanguageModelRequestMessage {
1332 role: Role::User,
1333 content,
1334 cache: false,
1335 };
1336
1337 messages.push(context_message);
1338 }
1339 }
1340
    /// Streams a completion for `request` from `model`, applying each event to
    /// the thread as it arrives (text/thinking chunks, tool uses, usage and
    /// queue-status updates) and emitting the corresponding [`ThreadEvent`]s.
    ///
    /// When the stream ends, pending tools are run (on `StopReason::ToolUse`)
    /// or the agent location is cleared; errors are mapped to user-facing
    /// `ShowError` events and the completion is canceled. The spawned task is
    /// tracked in `self.pending_completions` so it can be found by id and
    /// aborted.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        // Unique id for this completion, used to correlate status updates and
        // to remove the entry from `pending_completions` when done.
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only capture the request and response events when a callback is
        // installed; otherwise skip the bookkeeping entirely.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before streaming so the delta can be reported to
            // telemetry afterwards.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // The assistant message created for this response, once one
                // exists (created on StartMessage or lazily on first content).
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            // Malformed tool-input JSON is recoverable: record
                            // it as a failed tool use and keep streaming.
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage events report running totals for this
                                // request; fold only the delta into the
                                // thread's cumulative usage.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses need an assistant message to hang
                                // off of; create an empty one if none exists.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Partial input is forwarded to the UI so it
                                // can render the tool call as it streams in.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                // Status updates apply to this specific
                                // completion; look it up by id.
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message
                                        } => {
                                            return Err(anyhow!("completion request failed. code: {code}, message: {message}"));
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                            cx.emit(ThreadEvent::UsageUpdated(usage));
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so other scheduled tasks can make
                    // progress while the stream is active.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to their dedicated UI
                            // treatments; anything else becomes a generic
                            // error message built from the error chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report this completion's token usage delta to telemetry.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1688
    /// Generates a short title for the thread by asking the configured
    /// summary model, storing the first line of the response as
    /// `self.summary` and emitting [`ThreadEvent::SummaryGenerated`].
    ///
    /// Does nothing when no summary model is configured or its provider is
    /// not authenticated. Replaces (and thereby cancels) any previously
    /// running summary task.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        // Usage updates still count against the user's quota,
                        // so forward them even during summarization.
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            this.update(cx, |_, cx| {
                                cx.emit(ThreadEvent::UsageUpdated(RequestUsage {
                                    limit,
                                    amount: amount as i32,
                                }));
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    // Only the first line of output is used as the title.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1751
    /// Starts generating a detailed Markdown summary of the thread unless one
    /// is already generating (or generated) for the current last message.
    ///
    /// Progress is published through the `detailed_summary_tx`/`rx` watch
    /// channel; on success the thread is saved via `thread_store` so the
    /// summary can be reused later. Does nothing when the thread is empty or
    /// no authenticated summary model is available.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        // Tag the in-flight generation with the message it covers, so later
        // calls can tell whether it is still current.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed; reset the state so a retry is possible.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1841
1842 pub async fn wait_for_detailed_summary_or_text(
1843 this: &Entity<Self>,
1844 cx: &mut AsyncApp,
1845 ) -> Option<SharedString> {
1846 let mut detailed_summary_rx = this
1847 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1848 .ok()?;
1849 loop {
1850 match detailed_summary_rx.recv().await? {
1851 DetailedSummaryState::Generating { .. } => {}
1852 DetailedSummaryState::NotGenerated => {
1853 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1854 }
1855 DetailedSummaryState::Generated { text, .. } => return Some(text),
1856 }
1857 }
1858 }
1859
1860 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1861 self.detailed_summary_rx
1862 .borrow()
1863 .text()
1864 .unwrap_or_else(|| self.text().into())
1865 }
1866
    /// Whether a detailed-summary generation task is currently in flight.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1873
1874 pub fn use_pending_tools(
1875 &mut self,
1876 window: Option<AnyWindowHandle>,
1877 cx: &mut Context<Self>,
1878 model: Arc<dyn LanguageModel>,
1879 ) -> Vec<PendingToolUse> {
1880 self.auto_capture_telemetry(cx);
1881 let request = self.to_completion_request(model, cx);
1882 let messages = Arc::new(request.messages);
1883 let pending_tool_uses = self
1884 .tool_use
1885 .pending_tool_uses()
1886 .into_iter()
1887 .filter(|tool_use| tool_use.status.is_idle())
1888 .cloned()
1889 .collect::<Vec<_>>();
1890
1891 for tool_use in pending_tool_uses.iter() {
1892 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1893 if tool.needs_confirmation(&tool_use.input, cx)
1894 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1895 {
1896 self.tool_use.confirm_tool_use(
1897 tool_use.id.clone(),
1898 tool_use.ui_text.clone(),
1899 tool_use.input.clone(),
1900 messages.clone(),
1901 tool,
1902 );
1903 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1904 } else {
1905 self.run_tool(
1906 tool_use.id.clone(),
1907 tool_use.ui_text.clone(),
1908 tool_use.input.clone(),
1909 &messages,
1910 tool,
1911 window,
1912 cx,
1913 );
1914 }
1915 }
1916 }
1917
1918 pending_tool_uses
1919 }
1920
1921 pub fn receive_invalid_tool_json(
1922 &mut self,
1923 tool_use_id: LanguageModelToolUseId,
1924 tool_name: Arc<str>,
1925 invalid_json: Arc<str>,
1926 error: String,
1927 window: Option<AnyWindowHandle>,
1928 cx: &mut Context<Thread>,
1929 ) {
1930 log::error!("The model returned invalid input JSON: {invalid_json}");
1931
1932 let pending_tool_use = self.tool_use.insert_tool_output(
1933 tool_use_id.clone(),
1934 tool_name,
1935 Err(anyhow!("Error parsing input JSON: {error}")),
1936 self.configured_model.as_ref(),
1937 );
1938 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1939 pending_tool_use.ui_text.clone()
1940 } else {
1941 log::error!(
1942 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1943 );
1944 format!("Unknown tool {}", tool_use_id).into()
1945 };
1946
1947 cx.emit(ThreadEvent::InvalidToolInput {
1948 tool_use_id: tool_use_id.clone(),
1949 ui_text,
1950 invalid_input_json: invalid_json,
1951 });
1952
1953 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1954 }
1955
1956 pub fn run_tool(
1957 &mut self,
1958 tool_use_id: LanguageModelToolUseId,
1959 ui_text: impl Into<SharedString>,
1960 input: serde_json::Value,
1961 messages: &[LanguageModelRequestMessage],
1962 tool: Arc<dyn Tool>,
1963 window: Option<AnyWindowHandle>,
1964 cx: &mut Context<Thread>,
1965 ) {
1966 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1967 self.tool_use
1968 .run_pending_tool(tool_use_id, ui_text.into(), task);
1969 }
1970
1971 fn spawn_tool_use(
1972 &mut self,
1973 tool_use_id: LanguageModelToolUseId,
1974 messages: &[LanguageModelRequestMessage],
1975 input: serde_json::Value,
1976 tool: Arc<dyn Tool>,
1977 window: Option<AnyWindowHandle>,
1978 cx: &mut Context<Thread>,
1979 ) -> Task<()> {
1980 let tool_name: Arc<str> = tool.name().into();
1981
1982 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1983 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1984 } else {
1985 tool.run(
1986 input,
1987 messages,
1988 self.project.clone(),
1989 self.action_log.clone(),
1990 window,
1991 cx,
1992 )
1993 };
1994
1995 // Store the card separately if it exists
1996 if let Some(card) = tool_result.card.clone() {
1997 self.tool_use
1998 .insert_tool_result_card(tool_use_id.clone(), card);
1999 }
2000
2001 cx.spawn({
2002 async move |thread: WeakEntity<Thread>, cx| {
2003 let output = tool_result.output.await;
2004
2005 thread
2006 .update(cx, |thread, cx| {
2007 let pending_tool_use = thread.tool_use.insert_tool_output(
2008 tool_use_id.clone(),
2009 tool_name,
2010 output,
2011 thread.configured_model.as_ref(),
2012 );
2013 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2014 })
2015 .ok();
2016 }
2017 })
2018 }
2019
2020 fn tool_finished(
2021 &mut self,
2022 tool_use_id: LanguageModelToolUseId,
2023 pending_tool_use: Option<PendingToolUse>,
2024 canceled: bool,
2025 window: Option<AnyWindowHandle>,
2026 cx: &mut Context<Self>,
2027 ) {
2028 if self.all_tools_finished() {
2029 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2030 if !canceled {
2031 self.send_to_model(model.clone(), window, cx);
2032 }
2033 self.auto_capture_telemetry(cx);
2034 }
2035 }
2036
2037 cx.emit(ThreadEvent::ToolFinished {
2038 tool_use_id,
2039 pending_tool_use,
2040 });
2041 }
2042
2043 /// Cancels the last pending completion, if there are any pending.
2044 ///
2045 /// Returns whether a completion was canceled.
2046 pub fn cancel_last_completion(
2047 &mut self,
2048 window: Option<AnyWindowHandle>,
2049 cx: &mut Context<Self>,
2050 ) -> bool {
2051 let mut canceled = self.pending_completions.pop().is_some();
2052
2053 for pending_tool_use in self.tool_use.cancel_pending() {
2054 canceled = true;
2055 self.tool_finished(
2056 pending_tool_use.id.clone(),
2057 Some(pending_tool_use),
2058 true,
2059 window,
2060 cx,
2061 );
2062 }
2063
2064 self.finalize_pending_checkpoint(cx);
2065
2066 if canceled {
2067 cx.emit(ThreadEvent::CompletionCanceled);
2068 }
2069
2070 canceled
2071 }
2072
    /// Signals that any in-progress editing should be canceled.
    ///
    /// Listeners (such as `ActiveThread`) respond to the emitted
    /// [`ThreadEvent::CancelEditing`] by aborting their editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2080
    /// Returns the thread-level feedback rating, if one has been recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2084
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2088
    /// Records a feedback rating for `message_id` and reports it to
    /// telemetry, together with a serialized copy of the thread and a final
    /// project snapshot.
    ///
    /// A no-op (returning immediately) when the same rating was already
    /// recorded for this message. The returned task completes once the
    /// telemetry event has been flushed.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that goes into the telemetry payload before
        // moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2146
    /// Records thread-level feedback: attributed to the most recent assistant
    /// message when one exists (delegating to
    /// [`Self::report_message_feedback`]), otherwise stored on the thread
    /// itself and reported to telemetry without message-level fields.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach the rating to; report the bare
            // thread rating instead.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2192
2193 /// Create a snapshot of the current project state including git information and unsaved buffers.
2194 fn project_snapshot(
2195 project: Entity<Project>,
2196 cx: &mut Context<Self>,
2197 ) -> Task<Arc<ProjectSnapshot>> {
2198 let git_store = project.read(cx).git_store().clone();
2199 let worktree_snapshots: Vec<_> = project
2200 .read(cx)
2201 .visible_worktrees(cx)
2202 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2203 .collect();
2204
2205 cx.spawn(async move |_, cx| {
2206 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2207
2208 let mut unsaved_buffers = Vec::new();
2209 cx.update(|app_cx| {
2210 let buffer_store = project.read(app_cx).buffer_store();
2211 for buffer_handle in buffer_store.read(app_cx).buffers() {
2212 let buffer = buffer_handle.read(app_cx);
2213 if buffer.is_dirty() {
2214 if let Some(file) = buffer.file() {
2215 let path = file.path().to_string_lossy().to_string();
2216 unsaved_buffers.push(path);
2217 }
2218 }
2219 }
2220 })
2221 .ok();
2222
2223 Arc::new(ProjectSnapshot {
2224 worktree_snapshots,
2225 unsaved_buffer_paths: unsaved_buffers,
2226 timestamp: Utc::now(),
2227 })
2228 })
2229 }
2230
2231 fn worktree_snapshot(
2232 worktree: Entity<project::Worktree>,
2233 git_store: Entity<GitStore>,
2234 cx: &App,
2235 ) -> Task<WorktreeSnapshot> {
2236 cx.spawn(async move |cx| {
2237 // Get worktree path and snapshot
2238 let worktree_info = cx.update(|app_cx| {
2239 let worktree = worktree.read(app_cx);
2240 let path = worktree.abs_path().to_string_lossy().to_string();
2241 let snapshot = worktree.snapshot();
2242 (path, snapshot)
2243 });
2244
2245 let Ok((worktree_path, _snapshot)) = worktree_info else {
2246 return WorktreeSnapshot {
2247 worktree_path: String::new(),
2248 git_state: None,
2249 };
2250 };
2251
2252 let git_state = git_store
2253 .update(cx, |git_store, cx| {
2254 git_store
2255 .repositories()
2256 .values()
2257 .find(|repo| {
2258 repo.read(cx)
2259 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2260 .is_some()
2261 })
2262 .cloned()
2263 })
2264 .ok()
2265 .flatten()
2266 .map(|repo| {
2267 repo.update(cx, |repo, _| {
2268 let current_branch =
2269 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2270 repo.send_job(None, |state, _| async move {
2271 let RepositoryState::Local { backend, .. } = state else {
2272 return GitState {
2273 remote_url: None,
2274 head_sha: None,
2275 current_branch,
2276 diff: None,
2277 };
2278 };
2279
2280 let remote_url = backend.remote_url("origin");
2281 let head_sha = backend.head_sha().await;
2282 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2283
2284 GitState {
2285 remote_url,
2286 head_sha,
2287 current_branch,
2288 diff,
2289 }
2290 })
2291 })
2292 });
2293
2294 let git_state = match git_state {
2295 Some(git_state) => match git_state.ok() {
2296 Some(git_state) => git_state.await.ok(),
2297 None => None,
2298 },
2299 None => None,
2300 };
2301
2302 WorktreeSnapshot {
2303 worktree_path,
2304 git_state,
2305 }
2306 })
2307 }
2308
2309 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2310 let mut markdown = Vec::new();
2311
2312 if let Some(summary) = self.summary() {
2313 writeln!(markdown, "# {summary}\n")?;
2314 };
2315
2316 for message in self.messages() {
2317 writeln!(
2318 markdown,
2319 "## {role}\n",
2320 role = match message.role {
2321 Role::User => "User",
2322 Role::Assistant => "Assistant",
2323 Role::System => "System",
2324 }
2325 )?;
2326
2327 if !message.loaded_context.text.is_empty() {
2328 writeln!(markdown, "{}", message.loaded_context.text)?;
2329 }
2330
2331 if !message.loaded_context.images.is_empty() {
2332 writeln!(
2333 markdown,
2334 "\n{} images attached as context.\n",
2335 message.loaded_context.images.len()
2336 )?;
2337 }
2338
2339 for segment in &message.segments {
2340 match segment {
2341 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2342 MessageSegment::Thinking { text, .. } => {
2343 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2344 }
2345 MessageSegment::RedactedThinking(_) => {}
2346 }
2347 }
2348
2349 for tool_use in self.tool_uses_for_message(message.id, cx) {
2350 writeln!(
2351 markdown,
2352 "**Use Tool: {} ({})**",
2353 tool_use.name, tool_use.id
2354 )?;
2355 writeln!(markdown, "```json")?;
2356 writeln!(
2357 markdown,
2358 "{}",
2359 serde_json::to_string_pretty(&tool_use.input)?
2360 )?;
2361 writeln!(markdown, "```")?;
2362 }
2363
2364 for tool_result in self.tool_results_for_message(message.id) {
2365 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2366 if tool_result.is_error {
2367 write!(markdown, " (Error)")?;
2368 }
2369
2370 writeln!(markdown, "**\n")?;
2371 writeln!(markdown, "{}", tool_result.content)?;
2372 }
2373 }
2374
2375 Ok(String::from_utf8_lossy(&markdown).to_string())
2376 }
2377
2378 pub fn keep_edits_in_range(
2379 &mut self,
2380 buffer: Entity<language::Buffer>,
2381 buffer_range: Range<language::Anchor>,
2382 cx: &mut Context<Self>,
2383 ) {
2384 self.action_log.update(cx, |action_log, cx| {
2385 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2386 });
2387 }
2388
    /// Mark every outstanding agent edit in the action log as accepted.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2393
2394 pub fn reject_edits_in_ranges(
2395 &mut self,
2396 buffer: Entity<language::Buffer>,
2397 buffer_ranges: Vec<Range<language::Anchor>>,
2398 cx: &mut Context<Self>,
2399 ) -> Task<Result<()>> {
2400 self.action_log.update(cx, |action_log, cx| {
2401 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2402 })
2403 }
2404
    /// The log of buffer edits made by the agent during this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2408
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2412
2413 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2414 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2415 return;
2416 }
2417
2418 let now = Instant::now();
2419 if let Some(last) = self.last_auto_capture_at {
2420 if now.duration_since(last).as_secs() < 10 {
2421 return;
2422 }
2423 }
2424
2425 self.last_auto_capture_at = Some(now);
2426
2427 let thread_id = self.id().clone();
2428 let github_login = self
2429 .project
2430 .read(cx)
2431 .user_store()
2432 .read(cx)
2433 .current_user()
2434 .map(|user| user.github_login.clone());
2435 let client = self.project.read(cx).client().clone();
2436 let serialize_task = self.serialize(cx);
2437
2438 cx.background_executor()
2439 .spawn(async move {
2440 if let Ok(serialized_thread) = serialize_task.await {
2441 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2442 telemetry::event!(
2443 "Agent Thread Auto-Captured",
2444 thread_id = thread_id.to_string(),
2445 thread_data = thread_data,
2446 auto_capture_reason = "tracked_user",
2447 github_login = github_login
2448 );
2449
2450 client.telemetry().flush_events().await;
2451 }
2452 }
2453 })
2454 .detach();
2455 }
2456
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2460
2461 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2462 let Some(model) = self.configured_model.as_ref() else {
2463 return TotalTokenUsage::default();
2464 };
2465
2466 let max = model.model.max_token_count();
2467
2468 let index = self
2469 .messages
2470 .iter()
2471 .position(|msg| msg.id == message_id)
2472 .unwrap_or(0);
2473
2474 if index == 0 {
2475 return TotalTokenUsage { total: 0, max };
2476 }
2477
2478 let token_usage = &self
2479 .request_token_usage
2480 .get(index - 1)
2481 .cloned()
2482 .unwrap_or_default();
2483
2484 TotalTokenUsage {
2485 total: token_usage.total_tokens() as usize,
2486 max,
2487 }
2488 }
2489
2490 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2491 let model = self.configured_model.as_ref()?;
2492
2493 let max = model.model.max_token_count();
2494
2495 if let Some(exceeded_error) = &self.exceeded_window_error {
2496 if model.model.id() == exceeded_error.model_id {
2497 return Some(TotalTokenUsage {
2498 total: exceeded_error.token_count,
2499 max,
2500 });
2501 }
2502 }
2503
2504 let total = self
2505 .token_usage_at_last_message()
2506 .unwrap_or_default()
2507 .total_tokens() as usize;
2508
2509 Some(TotalTokenUsage { total, max })
2510 }
2511
2512 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2513 self.request_token_usage
2514 .get(self.messages.len().saturating_sub(1))
2515 .or_else(|| self.request_token_usage.last())
2516 .cloned()
2517 }
2518
2519 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2520 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2521 self.request_token_usage
2522 .resize(self.messages.len(), placeholder);
2523
2524 if let Some(last) = self.request_token_usage.last_mut() {
2525 *last = token_usage;
2526 }
2527 }
2528
2529 pub fn deny_tool_use(
2530 &mut self,
2531 tool_use_id: LanguageModelToolUseId,
2532 tool_name: Arc<str>,
2533 window: Option<AnyWindowHandle>,
2534 cx: &mut Context<Self>,
2535 ) {
2536 let err = Err(anyhow::anyhow!(
2537 "Permission to run tool action denied by user"
2538 ));
2539
2540 self.tool_use.insert_tool_output(
2541 tool_use_id.clone(),
2542 tool_name,
2543 err,
2544 self.configured_model.as_ref(),
2545 );
2546 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2547 }
2548}
2549
/// User-facing errors that can occur while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before further requests are accepted.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The model request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a header and message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2564
/// Events emitted by a `Thread` as completions stream in and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error should be shown to the user.
    ShowError(ThreadError),
    /// Updated request-usage information was received.
    UsageUpdated(RequestUsage),
    /// A completion chunk was streamed in.
    StreamedCompletion,
    /// A text chunk was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant text streamed in for the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text streamed in for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use was streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that failed to parse as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool needs user confirmation before it may run.
    ToolConfirmationNeeded,
    /// Editing should be canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2604
2605impl EventEmitter<ThreadEvent> for Thread {}
2606
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifier distinguishing this completion from others (presumably
    // allocated via `post_inc` by the spawning code — confirm at call site).
    id: usize,
    queue_state: QueueState,
    // Keeps the spawned completion work alive for the lifetime of this struct.
    _task: Task<()>,
}
2612
2613#[cfg(test)]
2614mod tests {
2615 use super::*;
2616 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2617 use assistant_settings::AssistantSettings;
2618 use assistant_tool::ToolRegistry;
2619 use context_server::ContextServerSettings;
2620 use editor::EditorSettings;
2621 use gpui::TestAppContext;
2622 use language_model::fake_provider::FakeLanguageModel;
2623 use project::{FakeFs, Project};
2624 use prompt_store::PromptBuilder;
2625 use serde_json::json;
2626 use settings::{Settings, SettingsStore};
2627 use std::sync::Arc;
2628 use theme::ThemeSettings;
2629 use util::path;
2630 use workspace::Workspace;
2631
    /// Verifies that file context attached to a user message is stored on the
    /// message itself and rendered verbatim into the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request: messages[0] is the system prompt, so the
        // user message (context + text) lands at index 1.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2707
    /// Verifies that `new_context_for_thread` only yields context not already
    /// attached to an earlier message, and that removing context is reflected
    /// in subsequent loads.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages (plus the system prompt at index 0)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Passing `Some(message2_id)` asks for context new relative to the
        // messages *before* message 2, so files 2-4 should all be returned.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2863
    /// Verifies that messages inserted with no attached context produce empty
    /// context text and appear unmodified in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request (index 0 is the system prompt)
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2941
    /// Verifies that when a buffer previously attached as context is modified,
    /// the completion request ends with a "files changed since last read"
    /// notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3028
    /// Register the global settings, stores, and tool registry required before
    /// constructing projects, threads, and workspaces in tests.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3046
    /// Helper to create a test project backed by a fake filesystem, populated
    /// with `files` under the `/test` root.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }
3056
3057 async fn setup_test_environment(
3058 cx: &mut TestAppContext,
3059 project: Entity<Project>,
3060 ) -> (
3061 Entity<Workspace>,
3062 Entity<ThreadStore>,
3063 Entity<Thread>,
3064 Entity<ContextStore>,
3065 Arc<dyn LanguageModel>,
3066 ) {
3067 let (workspace, cx) =
3068 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3069
3070 let thread_store = cx
3071 .update(|_, cx| {
3072 ThreadStore::load(
3073 project.clone(),
3074 cx.new(|_| ToolWorkingSet::default()),
3075 None,
3076 Arc::new(PromptBuilder::new(None).unwrap()),
3077 cx,
3078 )
3079 })
3080 .await
3081 .unwrap();
3082
3083 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3084 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3085
3086 let model = FakeLanguageModel::default();
3087 let model: Arc<dyn LanguageModel> = Arc::new(model);
3088
3089 (workspace, thread_store, thread, context_store, model)
3090 }
3091
3092 async fn add_file_to_context(
3093 project: &Entity<Project>,
3094 context_store: &Entity<ContextStore>,
3095 path: &str,
3096 cx: &mut TestAppContext,
3097 ) -> Result<Entity<language::Buffer>> {
3098 let buffer_path = project
3099 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3100 .unwrap();
3101
3102 let buffer = project
3103 .update(cx, |project, cx| {
3104 project.open_buffer(buffer_path.clone(), cx)
3105 })
3106 .await
3107 .unwrap();
3108
3109 context_store.update(cx, |context_store, cx| {
3110 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3111 });
3112
3113 Ok(buffer)
3114 }
3115}