1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
23};
24use project::Project;
25use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
26use prompt_store::PromptBuilder;
27use proto::Plan;
28use schemars::JsonSchema;
29use serde::{Deserialize, Serialize};
30use settings::Settings;
31use thiserror::Error;
32use util::{ResultExt as _, TryFutureExt as _, post_inc};
33use uuid::Uuid;
34
35use crate::context::{AssistantContext, ContextId, format_context_as_string};
36use crate::thread_store::{
37 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
38 SerializedToolUse, SharedProjectContext,
39};
40use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
41
42#[derive(Debug, Clone, Copy)]
43pub enum RequestKind {
44 Chat,
45 /// Used when summarizing a thread.
46 Summarize,
47}
48
49#[derive(
50 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
51)]
52pub struct ThreadId(Arc<str>);
53
54impl ThreadId {
55 pub fn new() -> Self {
56 Self(Uuid::new_v4().to_string().into())
57 }
58}
59
60impl std::fmt::Display for ThreadId {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(f, "{}", self.0)
63 }
64}
65
66impl From<&str> for ThreadId {
67 fn from(value: &str) -> Self {
68 Self(value.into())
69 }
70}
71
72#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
73pub struct MessageId(pub(crate) usize);
74
75impl MessageId {
76 fn post_inc(&mut self) -> Self {
77 Self(post_inc(&mut self.0))
78 }
79}
80
81/// A message in a [`Thread`].
82#[derive(Debug, Clone)]
83pub struct Message {
84 pub id: MessageId,
85 pub role: Role,
86 pub segments: Vec<MessageSegment>,
87 pub context: String,
88}
89
90impl Message {
    /// Returns whether the message's content should be displayed.
    ///
    /// The model sometimes runs tools without producing any text, or emits only a
    /// marker ([`USING_TOOL_MARKER`]); such messages are not worth displaying.
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().all(|segment| segment.should_display())
    }
96
97 pub fn push_thinking(&mut self, text: &str) {
98 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
99 segment.push_str(text);
100 } else {
101 self.segments
102 .push(MessageSegment::Thinking(text.to_string()));
103 }
104 }
105
106 pub fn push_text(&mut self, text: &str) {
107 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
108 segment.push_str(text);
109 } else {
110 self.segments.push(MessageSegment::Text(text.to_string()));
111 }
112 }
113
114 pub fn to_string(&self) -> String {
115 let mut result = String::new();
116
117 if !self.context.is_empty() {
118 result.push_str(&self.context);
119 }
120
121 for segment in &self.segments {
122 match segment {
123 MessageSegment::Text(text) => result.push_str(text),
124 MessageSegment::Thinking(text) => {
125 result.push_str("<think>");
126 result.push_str(text);
127 result.push_str("</think>");
128 }
129 }
130 }
131
132 result
133 }
134}
135
136#[derive(Debug, Clone, PartialEq, Eq)]
137pub enum MessageSegment {
138 Text(String),
139 Thinking(String),
140}
141
142impl MessageSegment {
143 pub fn text_mut(&mut self) -> &mut String {
144 match self {
145 Self::Text(text) => text,
146 Self::Thinking(text) => text,
147 }
148 }
149
    pub fn should_display(&self) -> bool {
        // We add USING_TOOL_MARKER when making a request that includes tool uses
        // without non-whitespace text around them, and this can cause the model
        // to mimic the pattern, so we consider those segments not displayable.
        match self {
            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::Thinking(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
        }
    }
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct ProjectSnapshot {
163 pub worktree_snapshots: Vec<WorktreeSnapshot>,
164 pub unsaved_buffer_paths: Vec<String>,
165 pub timestamp: DateTime<Utc>,
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct WorktreeSnapshot {
170 pub worktree_path: String,
171 pub git_state: Option<GitState>,
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct GitState {
176 pub remote_url: Option<String>,
177 pub head_sha: Option<String>,
178 pub current_branch: Option<String>,
179 pub diff: Option<String>,
180}
181
182#[derive(Clone)]
183pub struct ThreadCheckpoint {
184 message_id: MessageId,
185 git_checkpoint: GitStoreCheckpoint,
186}
187
188#[derive(Copy, Clone, Debug, PartialEq, Eq)]
189pub enum ThreadFeedback {
190 Positive,
191 Negative,
192}
193
194pub enum LastRestoreCheckpoint {
195 Pending {
196 message_id: MessageId,
197 },
198 Error {
199 message_id: MessageId,
200 error: String,
201 },
202}
203
204impl LastRestoreCheckpoint {
205 pub fn message_id(&self) -> MessageId {
206 match self {
207 LastRestoreCheckpoint::Pending { message_id } => *message_id,
208 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
209 }
210 }
211}
212
213#[derive(Clone, Debug, Default, Serialize, Deserialize)]
214pub enum DetailedSummaryState {
215 #[default]
216 NotGenerated,
217 Generating {
218 message_id: MessageId,
219 },
220 Generated {
221 text: SharedString,
222 message_id: MessageId,
223 },
224}
225
226#[derive(Default)]
227pub struct TotalTokenUsage {
228 pub total: usize,
229 pub max: usize,
230 pub ratio: TokenUsageRatio,
231}
232
233#[derive(Debug, Default, PartialEq, Eq)]
234pub enum TokenUsageRatio {
235 #[default]
236 Normal,
237 Warning,
238 Exceeded,
239}
240
241/// A thread of conversation with the LLM.
242pub struct Thread {
243 id: ThreadId,
244 updated_at: DateTime<Utc>,
245 summary: Option<SharedString>,
246 pending_summary: Task<Option<()>>,
247 detailed_summary_state: DetailedSummaryState,
248 messages: Vec<Message>,
249 next_message_id: MessageId,
250 context: BTreeMap<ContextId, AssistantContext>,
251 context_by_message: HashMap<MessageId, Vec<ContextId>>,
252 project_context: SharedProjectContext,
253 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
254 completion_count: usize,
255 pending_completions: Vec<PendingCompletion>,
256 project: Entity<Project>,
257 prompt_builder: Arc<PromptBuilder>,
258 tools: Entity<ToolWorkingSet>,
259 tool_use: ToolUseState,
260 action_log: Entity<ActionLog>,
261 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
262 pending_checkpoint: Option<ThreadCheckpoint>,
263 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
264 cumulative_token_usage: TokenUsage,
265 exceeded_window_error: Option<ExceededWindowError>,
266 feedback: Option<ThreadFeedback>,
267 message_feedback: HashMap<MessageId, ThreadFeedback>,
268 last_auto_capture_at: Option<Instant>,
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct ExceededWindowError {
273 /// Model used when last message exceeded context window
274 model_id: LanguageModelId,
275 /// Token count including last message
276 token_count: usize,
277}
278
279impl Thread {
280 pub fn new(
281 project: Entity<Project>,
282 tools: Entity<ToolWorkingSet>,
283 prompt_builder: Arc<PromptBuilder>,
284 system_prompt: SharedProjectContext,
285 cx: &mut Context<Self>,
286 ) -> Self {
287 Self {
288 id: ThreadId::new(),
289 updated_at: Utc::now(),
290 summary: None,
291 pending_summary: Task::ready(None),
292 detailed_summary_state: DetailedSummaryState::NotGenerated,
293 messages: Vec::new(),
294 next_message_id: MessageId(0),
295 context: BTreeMap::default(),
296 context_by_message: HashMap::default(),
297 project_context: system_prompt,
298 checkpoints_by_message: HashMap::default(),
299 completion_count: 0,
300 pending_completions: Vec::new(),
301 project: project.clone(),
302 prompt_builder,
303 tools: tools.clone(),
304 last_restore_checkpoint: None,
305 pending_checkpoint: None,
306 tool_use: ToolUseState::new(tools.clone()),
307 action_log: cx.new(|_| ActionLog::new(project.clone())),
308 initial_project_snapshot: {
309 let project_snapshot = Self::project_snapshot(project, cx);
310 cx.foreground_executor()
311 .spawn(async move { Some(project_snapshot.await) })
312 .shared()
313 },
314 cumulative_token_usage: TokenUsage::default(),
315 exceeded_window_error: None,
316 feedback: None,
317 message_feedback: HashMap::default(),
318 last_auto_capture_at: None,
319 }
320 }
321
322 pub fn deserialize(
323 id: ThreadId,
324 serialized: SerializedThread,
325 project: Entity<Project>,
326 tools: Entity<ToolWorkingSet>,
327 prompt_builder: Arc<PromptBuilder>,
328 project_context: SharedProjectContext,
329 cx: &mut Context<Self>,
330 ) -> Self {
331 let next_message_id = MessageId(
332 serialized
333 .messages
334 .last()
335 .map(|message| message.id.0 + 1)
336 .unwrap_or(0),
337 );
338 let tool_use =
339 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
340
341 Self {
342 id,
343 updated_at: serialized.updated_at,
344 summary: Some(serialized.summary),
345 pending_summary: Task::ready(None),
346 detailed_summary_state: serialized.detailed_summary_state,
347 messages: serialized
348 .messages
349 .into_iter()
350 .map(|message| Message {
351 id: message.id,
352 role: message.role,
353 segments: message
354 .segments
355 .into_iter()
356 .map(|segment| match segment {
357 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
358 SerializedMessageSegment::Thinking { text } => {
359 MessageSegment::Thinking(text)
360 }
361 })
362 .collect(),
363 context: message.context,
364 })
365 .collect(),
366 next_message_id,
367 context: BTreeMap::default(),
368 context_by_message: HashMap::default(),
369 project_context,
370 checkpoints_by_message: HashMap::default(),
371 completion_count: 0,
372 pending_completions: Vec::new(),
373 last_restore_checkpoint: None,
374 pending_checkpoint: None,
375 project: project.clone(),
376 prompt_builder,
377 tools,
378 tool_use,
379 action_log: cx.new(|_| ActionLog::new(project)),
380 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
381 cumulative_token_usage: serialized.cumulative_token_usage,
382 exceeded_window_error: None,
383 feedback: None,
384 message_feedback: HashMap::default(),
385 last_auto_capture_at: None,
386 }
387 }
388
389 pub fn id(&self) -> &ThreadId {
390 &self.id
391 }
392
393 pub fn is_empty(&self) -> bool {
394 self.messages.is_empty()
395 }
396
397 pub fn updated_at(&self) -> DateTime<Utc> {
398 self.updated_at
399 }
400
401 pub fn touch_updated_at(&mut self) {
402 self.updated_at = Utc::now();
403 }
404
405 pub fn summary(&self) -> Option<SharedString> {
406 self.summary.clone()
407 }
408
409 pub fn project_context(&self) -> SharedProjectContext {
410 self.project_context.clone()
411 }
412
413 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
414
415 pub fn summary_or_default(&self) -> SharedString {
416 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
417 }
418
419 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
420 let Some(current_summary) = &self.summary else {
421 // Don't allow setting summary until generated
422 return;
423 };
424
425 let mut new_summary = new_summary.into();
426
427 if new_summary.is_empty() {
428 new_summary = Self::DEFAULT_SUMMARY;
429 }
430
431 if current_summary != &new_summary {
432 self.summary = Some(new_summary);
433 cx.emit(ThreadEvent::SummaryChanged);
434 }
435 }
436
437 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
438 self.latest_detailed_summary()
439 .unwrap_or_else(|| self.text().into())
440 }
441
442 fn latest_detailed_summary(&self) -> Option<SharedString> {
443 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
444 Some(text.clone())
445 } else {
446 None
447 }
448 }
449
450 pub fn message(&self, id: MessageId) -> Option<&Message> {
451 self.messages.iter().find(|message| message.id == id)
452 }
453
454 pub fn messages(&self) -> impl Iterator<Item = &Message> {
455 self.messages.iter()
456 }
457
458 pub fn is_generating(&self) -> bool {
459 !self.pending_completions.is_empty() || !self.all_tools_finished()
460 }
461
462 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
463 &self.tools
464 }
465
466 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
467 self.tool_use
468 .pending_tool_uses()
469 .into_iter()
470 .find(|tool_use| &tool_use.id == id)
471 }
472
473 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
474 self.tool_use
475 .pending_tool_uses()
476 .into_iter()
477 .filter(|tool_use| tool_use.status.needs_confirmation())
478 }
479
480 pub fn has_pending_tool_uses(&self) -> bool {
481 !self.tool_use.pending_tool_uses().is_empty()
482 }
483
484 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
485 self.checkpoints_by_message.get(&id).cloned()
486 }
487
488 pub fn restore_checkpoint(
489 &mut self,
490 checkpoint: ThreadCheckpoint,
491 cx: &mut Context<Self>,
492 ) -> Task<Result<()>> {
493 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
494 message_id: checkpoint.message_id,
495 });
496 cx.emit(ThreadEvent::CheckpointChanged);
497 cx.notify();
498
499 let git_store = self.project().read(cx).git_store().clone();
500 let restore = git_store.update(cx, |git_store, cx| {
501 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
502 });
503
504 cx.spawn(async move |this, cx| {
505 let result = restore.await;
506 this.update(cx, |this, cx| {
507 if let Err(err) = result.as_ref() {
508 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
509 message_id: checkpoint.message_id,
510 error: err.to_string(),
511 });
512 } else {
513 this.truncate(checkpoint.message_id, cx);
514 this.last_restore_checkpoint = None;
515 }
516 this.pending_checkpoint = None;
517 cx.emit(ThreadEvent::CheckpointChanged);
518 cx.notify();
519 })?;
520 result
521 })
522 }
523
524 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
525 let pending_checkpoint = if self.is_generating() {
526 return;
527 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
528 checkpoint
529 } else {
530 return;
531 };
532
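        // Compare the repository state captured before the user's message with the
        // state now that the turn has finished: if nothing changed, the pending
        // checkpoint is discarded; otherwise it is kept so the user can restore it.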
533 let git_store = self.project.read(cx).git_store().clone();
534 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
535 cx.spawn(async move |this, cx| match final_checkpoint.await {
536 Ok(final_checkpoint) => {
537 let equal = git_store
538 .update(cx, |store, cx| {
539 store.compare_checkpoints(
540 pending_checkpoint.git_checkpoint.clone(),
541 final_checkpoint.clone(),
542 cx,
543 )
544 })?
545 .await
546 .unwrap_or(false);
547
548 if equal {
549 git_store
550 .update(cx, |store, cx| {
551 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
552 })?
553 .detach();
554 } else {
555 this.update(cx, |this, cx| {
556 this.insert_checkpoint(pending_checkpoint, cx)
557 })?;
558 }
559
560 git_store
561 .update(cx, |store, cx| {
562 store.delete_checkpoint(final_checkpoint, cx)
563 })?
564 .detach();
565
566 Ok(())
567 }
568 Err(_) => this.update(cx, |this, cx| {
569 this.insert_checkpoint(pending_checkpoint, cx)
570 }),
571 })
572 .detach();
573 }
574
575 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
576 self.checkpoints_by_message
577 .insert(checkpoint.message_id, checkpoint);
578 cx.emit(ThreadEvent::CheckpointChanged);
579 cx.notify();
580 }
581
582 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
583 self.last_restore_checkpoint.as_ref()
584 }
585
586 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
587 let Some(message_ix) = self
588 .messages
589 .iter()
590 .rposition(|message| message.id == message_id)
591 else {
592 return;
593 };
594 for deleted_message in self.messages.drain(message_ix..) {
595 self.context_by_message.remove(&deleted_message.id);
596 self.checkpoints_by_message.remove(&deleted_message.id);
597 }
598 cx.notify();
599 }
600
601 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
602 self.context_by_message
603 .get(&id)
604 .into_iter()
605 .flat_map(|context| {
606 context
607 .iter()
                    .filter_map(|context_id| self.context.get(context_id))
609 })
610 }
611
612 /// Returns whether all of the tool uses have finished running.
613 pub fn all_tools_finished(&self) -> bool {
614 // If the only pending tool uses left are the ones with errors, then
615 // that means that we've finished running all of the pending tools.
616 self.tool_use
617 .pending_tool_uses()
618 .iter()
619 .all(|tool_use| tool_use.status.is_error())
620 }
621
622 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
623 self.tool_use.tool_uses_for_message(id, cx)
624 }
625
626 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
627 self.tool_use.tool_results_for_message(id)
628 }
629
630 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
631 self.tool_use.tool_result(id)
632 }
633
634 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
635 Some(&self.tool_use.tool_result(id)?.content)
636 }
637
638 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
639 self.tool_use.tool_result_card(id).cloned()
640 }
641
642 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
643 self.tool_use.message_has_tool_results(message_id)
644 }
645
646 pub fn insert_user_message(
647 &mut self,
648 text: impl Into<String>,
649 context: Vec<AssistantContext>,
650 git_checkpoint: Option<GitStoreCheckpoint>,
651 cx: &mut Context<Self>,
652 ) -> MessageId {
653 let text = text.into();
654
655 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
656
657 // Filter out contexts that have already been included in previous messages
658 let new_context: Vec<_> = context
659 .into_iter()
660 .filter(|ctx| !self.context.contains_key(&ctx.id()))
661 .collect();
662
663 if !new_context.is_empty() {
664 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
665 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
666 message.context = context_string;
667 }
668 }
669
670 self.action_log.update(cx, |log, cx| {
671 // Track all buffers added as context
672 for ctx in &new_context {
673 match ctx {
674 AssistantContext::File(file_ctx) => {
675 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
676 }
677 AssistantContext::Directory(dir_ctx) => {
678 for context_buffer in &dir_ctx.context_buffers {
679 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
680 }
681 }
682 AssistantContext::Symbol(symbol_ctx) => {
683 log.buffer_added_as_context(
684 symbol_ctx.context_symbol.buffer.clone(),
685 cx,
686 );
687 }
688 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
689 }
690 }
691 });
692 }
693
694 let context_ids = new_context
695 .iter()
696 .map(|context| context.id())
697 .collect::<Vec<_>>();
698 self.context.extend(
699 new_context
700 .into_iter()
701 .map(|context| (context.id(), context)),
702 );
703 self.context_by_message.insert(message_id, context_ids);
704
705 if let Some(git_checkpoint) = git_checkpoint {
706 self.pending_checkpoint = Some(ThreadCheckpoint {
707 message_id,
708 git_checkpoint,
709 });
710 }
711
712 self.auto_capture_telemetry(cx);
713
714 message_id
715 }
716
717 pub fn insert_message(
718 &mut self,
719 role: Role,
720 segments: Vec<MessageSegment>,
721 cx: &mut Context<Self>,
722 ) -> MessageId {
723 let id = self.next_message_id.post_inc();
724 self.messages.push(Message {
725 id,
726 role,
727 segments,
728 context: String::new(),
729 });
730 self.touch_updated_at();
731 cx.emit(ThreadEvent::MessageAdded(id));
732 id
733 }
734
735 pub fn edit_message(
736 &mut self,
737 id: MessageId,
738 new_role: Role,
739 new_segments: Vec<MessageSegment>,
740 cx: &mut Context<Self>,
741 ) -> bool {
742 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
743 return false;
744 };
745 message.role = new_role;
746 message.segments = new_segments;
747 self.touch_updated_at();
748 cx.emit(ThreadEvent::MessageEdited(id));
749 true
750 }
751
752 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
753 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
754 return false;
755 };
756 self.messages.remove(index);
757 self.context_by_message.remove(&id);
758 self.touch_updated_at();
759 cx.emit(ThreadEvent::MessageDeleted(id));
760 true
761 }
762
763 /// Returns the representation of this [`Thread`] in a textual form.
764 ///
765 /// This is the representation we use when attaching a thread as context to another thread.
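    ///
    /// As a rough sketch (message wording is illustrative), a short exchange
    /// renders as:
    ///
    /// ```text
    /// User:
    /// What does this function do?
    /// Assistant:
    /// <think>inspect the body first</think>It prints a greeting.
    /// ```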
766 pub fn text(&self) -> String {
767 let mut text = String::new();
768
769 for message in &self.messages {
770 text.push_str(match message.role {
771 language_model::Role::User => "User:",
772 language_model::Role::Assistant => "Assistant:",
773 language_model::Role::System => "System:",
774 });
775 text.push('\n');
776
777 for segment in &message.segments {
778 match segment {
779 MessageSegment::Text(content) => text.push_str(content),
780 MessageSegment::Thinking(content) => {
781 text.push_str(&format!("<think>{}</think>", content))
782 }
783 }
784 }
785 text.push('\n');
786 }
787
788 text
789 }
790
791 /// Serializes this thread into a format for storage or telemetry.
792 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
793 let initial_project_snapshot = self.initial_project_snapshot.clone();
794 cx.spawn(async move |this, cx| {
795 let initial_project_snapshot = initial_project_snapshot.await;
796 this.read_with(cx, |this, cx| SerializedThread {
797 version: SerializedThread::VERSION.to_string(),
798 summary: this.summary_or_default(),
799 updated_at: this.updated_at(),
800 messages: this
801 .messages()
802 .map(|message| SerializedMessage {
803 id: message.id,
804 role: message.role,
805 segments: message
806 .segments
807 .iter()
808 .map(|segment| match segment {
809 MessageSegment::Text(text) => {
810 SerializedMessageSegment::Text { text: text.clone() }
811 }
812 MessageSegment::Thinking(text) => {
813 SerializedMessageSegment::Thinking { text: text.clone() }
814 }
815 })
816 .collect(),
817 tool_uses: this
818 .tool_uses_for_message(message.id, cx)
819 .into_iter()
820 .map(|tool_use| SerializedToolUse {
821 id: tool_use.id,
822 name: tool_use.name,
823 input: tool_use.input,
824 })
825 .collect(),
826 tool_results: this
827 .tool_results_for_message(message.id)
828 .into_iter()
829 .map(|tool_result| SerializedToolResult {
830 tool_use_id: tool_result.tool_use_id.clone(),
831 is_error: tool_result.is_error,
832 content: tool_result.content.clone(),
833 })
834 .collect(),
835 context: message.context.clone(),
836 })
837 .collect(),
838 initial_project_snapshot,
839 cumulative_token_usage: this.cumulative_token_usage,
840 detailed_summary_state: this.detailed_summary_state.clone(),
841 exceeded_window_error: this.exceeded_window_error.clone(),
842 })
843 })
844 }
845
846 pub fn send_to_model(
847 &mut self,
848 model: Arc<dyn LanguageModel>,
849 request_kind: RequestKind,
850 cx: &mut Context<Self>,
851 ) {
852 let mut request = self.to_completion_request(request_kind, cx);
853 if model.supports_tools() {
854 request.tools = {
855 let mut tools = Vec::new();
856 tools.extend(
857 self.tools()
858 .read(cx)
859 .enabled_tools(cx)
860 .into_iter()
861 .filter_map(|tool| {
862 // Skip tools that cannot be supported
863 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
864 Some(LanguageModelRequestTool {
865 name: tool.name(),
866 description: tool.description(),
867 input_schema,
868 })
869 }),
870 );
871
872 tools
873 };
874 }
875
876 self.stream_completion(request, model, cx);
877 }
878
879 pub fn used_tools_since_last_user_message(&self) -> bool {
880 for message in self.messages.iter().rev() {
881 if self.tool_use.message_has_tool_results(message.id) {
882 return true;
883 } else if message.role == Role::User {
884 return false;
885 }
886 }
887
888 false
889 }
890
891 pub fn to_completion_request(
892 &self,
893 request_kind: RequestKind,
894 cx: &App,
895 ) -> LanguageModelRequest {
896 let mut request = LanguageModelRequest {
897 messages: vec![],
898 tools: Vec::new(),
899 stop: Vec::new(),
900 temperature: None,
901 };
902
903 if let Some(project_context) = self.project_context.borrow().as_ref() {
904 if let Some(system_prompt) = self
905 .prompt_builder
906 .generate_assistant_system_prompt(project_context)
907 .context("failed to generate assistant system prompt")
908 .log_err()
909 {
910 request.messages.push(LanguageModelRequestMessage {
911 role: Role::System,
912 content: vec![MessageContent::Text(system_prompt)],
913 cache: true,
914 });
915 }
916 } else {
917 log::error!("project_context not set.")
918 }
919
920 for message in &self.messages {
921 let mut request_message = LanguageModelRequestMessage {
922 role: message.role,
923 content: Vec::new(),
924 cache: false,
925 };
926
927 match request_kind {
928 RequestKind::Chat => {
929 self.tool_use
930 .attach_tool_results(message.id, &mut request_message);
931 }
932 RequestKind::Summarize => {
933 // We don't care about tool use during summarization.
934 if self.tool_use.message_has_tool_results(message.id) {
935 continue;
936 }
937 }
938 }
939
940 if !message.segments.is_empty() {
941 request_message
942 .content
943 .push(MessageContent::Text(message.to_string()));
944 }
945
946 match request_kind {
947 RequestKind::Chat => {
948 self.tool_use
949 .attach_tool_uses(message.id, &mut request_message);
950 }
951 RequestKind::Summarize => {
952 // We don't care about tool use during summarization.
953 }
954 };
955
956 request.messages.push(request_message);
957 }
958
959 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
960 if let Some(last) = request.messages.last_mut() {
961 last.cache = true;
962 }
963
964 self.attached_tracked_files_state(&mut request.messages, cx);
965
966 request
967 }
968
969 fn attached_tracked_files_state(
970 &self,
971 messages: &mut Vec<LanguageModelRequestMessage>,
972 cx: &App,
973 ) {
974 const STALE_FILES_HEADER: &str = "These files changed since last read:";
975
976 let mut stale_message = String::new();
977
978 let action_log = self.action_log.read(cx);
979
980 for stale_file in action_log.stale_buffers(cx) {
981 let Some(file) = stale_file.read(cx).file() else {
982 continue;
983 };
984
985 if stale_message.is_empty() {
                writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
987 }
988
989 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
990 }
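        // At this point `stale_message` is either empty or looks roughly like
        // the following (paths are illustrative):
        //
        //   These files changed since last read:
        //   - src/main.rs
        //   - src/lib.rs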
991
992 let mut content = Vec::with_capacity(2);
993
994 if !stale_message.is_empty() {
995 content.push(stale_message.into());
996 }
997
998 if action_log.has_edited_files_since_project_diagnostics_check() {
999 content.push(
1000 "\n\nWhen you're done making changes, make sure to check project diagnostics \
1001 and fix all errors AND warnings you introduced! \
1002 DO NOT mention you're going to do this until you're done."
1003 .into(),
1004 );
1005 }
1006
1007 if !content.is_empty() {
1008 let context_message = LanguageModelRequestMessage {
1009 role: Role::User,
1010 content,
1011 cache: false,
1012 };
1013
1014 messages.push(context_message);
1015 }
1016 }
1017
1018 pub fn stream_completion(
1019 &mut self,
1020 request: LanguageModelRequest,
1021 model: Arc<dyn LanguageModel>,
1022 cx: &mut Context<Self>,
1023 ) {
1024 let pending_completion_id = post_inc(&mut self.completion_count);
1025
1026 let task = cx.spawn(async move |thread, cx| {
1027 let stream = model.stream_completion(request, &cx);
1028 let initial_token_usage =
1029 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1030 let stream_completion = async {
1031 let mut events = stream.await?;
1032 let mut stop_reason = StopReason::EndTurn;
1033 let mut current_token_usage = TokenUsage::default();
1034
1035 while let Some(event) = events.next().await {
1036 let event = event?;
1037
1038 thread.update(cx, |thread, cx| {
1039 match event {
1040 LanguageModelCompletionEvent::StartMessage { .. } => {
1041 thread.insert_message(
1042 Role::Assistant,
1043 vec![MessageSegment::Text(String::new())],
1044 cx,
1045 );
1046 }
1047 LanguageModelCompletionEvent::Stop(reason) => {
1048 stop_reason = reason;
1049 }
1050 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
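                                    // Usage events appear to report cumulative totals for the current
                                    // request, so only the delta since the previous update is added to
                                    // the thread-wide running total.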
1051 thread.cumulative_token_usage = thread.cumulative_token_usage
1052 + token_usage
1053 - current_token_usage;
1054 current_token_usage = token_usage;
1055 }
1056 LanguageModelCompletionEvent::Text(chunk) => {
1057 if let Some(last_message) = thread.messages.last_mut() {
1058 if last_message.role == Role::Assistant {
1059 last_message.push_text(&chunk);
1060 cx.emit(ThreadEvent::StreamedAssistantText(
1061 last_message.id,
1062 chunk,
1063 ));
1064 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks
                                        // the beginning of a new Assistant response.
1067 //
1068 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1069 // will result in duplicating the text of the chunk in the rendered Markdown.
1070 thread.insert_message(
1071 Role::Assistant,
1072 vec![MessageSegment::Text(chunk.to_string())],
1073 cx,
1074 );
1075 };
1076 }
1077 }
1078 LanguageModelCompletionEvent::Thinking(chunk) => {
1079 if let Some(last_message) = thread.messages.last_mut() {
1080 if last_message.role == Role::Assistant {
1081 last_message.push_thinking(&chunk);
1082 cx.emit(ThreadEvent::StreamedAssistantThinking(
1083 last_message.id,
1084 chunk,
1085 ));
1086 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks
                                        // the beginning of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking`
                                        // event here, as it would duplicate the chunk's text in the rendered
                                        // Markdown.
1092 thread.insert_message(
1093 Role::Assistant,
1094 vec![MessageSegment::Thinking(chunk.to_string())],
1095 cx,
1096 );
1097 };
1098 }
1099 }
1100 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1101 let last_assistant_message_id = thread
1102 .messages
1103 .iter_mut()
1104 .rfind(|message| message.role == Role::Assistant)
1105 .map(|message| message.id)
1106 .unwrap_or_else(|| {
1107 thread.insert_message(Role::Assistant, vec![], cx)
1108 });
1109
1110 thread.tool_use.request_tool_use(
1111 last_assistant_message_id,
1112 tool_use,
1113 cx,
1114 );
1115 }
1116 }
1117
1118 thread.touch_updated_at();
1119 cx.emit(ThreadEvent::StreamedCompletion);
1120 cx.notify();
1121
1122 thread.auto_capture_telemetry(cx);
1123 })?;
1124
1125 smol::future::yield_now().await;
1126 }
1127
1128 thread.update(cx, |thread, cx| {
1129 thread
1130 .pending_completions
1131 .retain(|completion| completion.id != pending_completion_id);
1132
1133 if thread.summary.is_none() && thread.messages.len() >= 2 {
1134 thread.summarize(cx);
1135 }
1136 })?;
1137
1138 anyhow::Ok(stop_reason)
1139 };
1140
1141 let result = stream_completion.await;
1142
1143 thread
1144 .update(cx, |thread, cx| {
1145 thread.finalize_pending_checkpoint(cx);
1146 match result.as_ref() {
1147 Ok(stop_reason) => match stop_reason {
1148 StopReason::ToolUse => {
1149 let tool_uses = thread.use_pending_tools(cx);
1150 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1151 }
1152 StopReason::EndTurn => {}
1153 StopReason::MaxTokens => {}
1154 },
1155 Err(error) => {
1156 if error.is::<PaymentRequiredError>() {
1157 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1158 } else if error.is::<MaxMonthlySpendReachedError>() {
1159 cx.emit(ThreadEvent::ShowError(
1160 ThreadError::MaxMonthlySpendReached,
1161 ));
1162 } else if let Some(error) =
1163 error.downcast_ref::<ModelRequestLimitReachedError>()
1164 {
1165 cx.emit(ThreadEvent::ShowError(
1166 ThreadError::ModelRequestLimitReached { plan: error.plan },
1167 ));
1168 } else if let Some(known_error) =
1169 error.downcast_ref::<LanguageModelKnownError>()
1170 {
1171 match known_error {
1172 LanguageModelKnownError::ContextWindowLimitExceeded {
1173 tokens,
1174 } => {
1175 thread.exceeded_window_error = Some(ExceededWindowError {
1176 model_id: model.id(),
1177 token_count: *tokens,
1178 });
1179 cx.notify();
1180 }
1181 }
1182 } else {
1183 let error_message = error
1184 .chain()
1185 .map(|err| err.to_string())
1186 .collect::<Vec<_>>()
1187 .join("\n");
1188 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1189 header: "Error interacting with language model".into(),
1190 message: SharedString::from(error_message.clone()),
1191 }));
1192 }
1193
1194 thread.cancel_last_completion(cx);
1195 }
1196 }
1197 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1198
1199 thread.auto_capture_telemetry(cx);
1200
1201 if let Ok(initial_usage) = initial_token_usage {
1202 let usage = thread.cumulative_token_usage - initial_usage;
1203
1204 telemetry::event!(
1205 "Assistant Thread Completion",
1206 thread_id = thread.id().to_string(),
1207 model = model.telemetry_id(),
1208 model_provider = model.provider_id().to_string(),
1209 input_tokens = usage.input_tokens,
1210 output_tokens = usage.output_tokens,
1211 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1212 cache_read_input_tokens = usage.cache_read_input_tokens,
1213 );
1214 }
1215 })
1216 .ok();
1217 });
1218
1219 self.pending_completions.push(PendingCompletion {
1220 id: pending_completion_id,
1221 _task: task,
1222 });
1223 }
1224
1225 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1226 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1227 return;
1228 };
1229
1230 if !model.provider.is_authenticated(cx) {
1231 return;
1232 }
1233
1234 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1235 request.messages.push(LanguageModelRequestMessage {
1236 role: Role::User,
1237 content: vec![
1238 "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1239 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1240 If the conversation is about a specific subject, include it in the title. \
1241 Be descriptive. DO NOT speak in the first person."
1242 .into(),
1243 ],
1244 cache: false,
1245 });
1246
1247 self.pending_summary = cx.spawn(async move |this, cx| {
1248 async move {
1249 let stream = model.model.stream_completion_text(request, &cx);
1250 let mut messages = stream.await?;
1251
1252 let mut new_summary = String::new();
1253 while let Some(message) = messages.stream.next().await {
1254 let text = message?;
1255 let mut lines = text.lines();
1256 new_summary.extend(lines.next());
1257
1258 // Stop if the LLM generated multiple lines.
1259 if lines.next().is_some() {
1260 break;
1261 }
1262 }
1263
1264 this.update(cx, |this, cx| {
1265 if !new_summary.is_empty() {
1266 this.summary = Some(new_summary.into());
1267 }
1268
1269 cx.emit(ThreadEvent::SummaryGenerated);
1270 })?;
1271
1272 anyhow::Ok(())
1273 }
1274 .log_err()
1275 .await
1276 });
1277 }
1278
1279 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1280 let last_message_id = self.messages.last().map(|message| message.id)?;
1281
1282 match &self.detailed_summary_state {
1283 DetailedSummaryState::Generating { message_id, .. }
1284 | DetailedSummaryState::Generated { message_id, .. }
1285 if *message_id == last_message_id =>
1286 {
1287 // Already up-to-date
1288 return None;
1289 }
1290 _ => {}
1291 }
1292
1293 let ConfiguredModel { model, provider } =
1294 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1295
1296 if !provider.is_authenticated(cx) {
1297 return None;
1298 }
1299
1300 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1301
1302 request.messages.push(LanguageModelRequestMessage {
1303 role: Role::User,
1304 content: vec![
1305 "Generate a detailed summary of this conversation. Include:\n\
1306 1. A brief overview of what was discussed\n\
1307 2. Key facts or information discovered\n\
1308 3. Outcomes or conclusions reached\n\
1309 4. Any action items or next steps if any\n\
1310 Format it in Markdown with headings and bullet points."
1311 .into(),
1312 ],
1313 cache: false,
1314 });
1315
1316 let task = cx.spawn(async move |thread, cx| {
1317 let stream = model.stream_completion_text(request, &cx);
1318 let Some(mut messages) = stream.await.log_err() else {
1319 thread
1320 .update(cx, |this, _cx| {
1321 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1322 })
1323 .log_err();
1324
1325 return;
1326 };
1327
1328 let mut new_detailed_summary = String::new();
1329
1330 while let Some(chunk) = messages.stream.next().await {
1331 if let Some(chunk) = chunk.log_err() {
1332 new_detailed_summary.push_str(&chunk);
1333 }
1334 }
1335
1336 thread
1337 .update(cx, |this, _cx| {
1338 this.detailed_summary_state = DetailedSummaryState::Generated {
1339 text: new_detailed_summary.into(),
1340 message_id: last_message_id,
1341 };
1342 })
1343 .log_err();
1344 });
1345
1346 self.detailed_summary_state = DetailedSummaryState::Generating {
1347 message_id: last_message_id,
1348 };
1349
1350 Some(task)
1351 }
1352
1353 pub fn is_generating_detailed_summary(&self) -> bool {
1354 matches!(
1355 self.detailed_summary_state,
1356 DetailedSummaryState::Generating { .. }
1357 )
1358 }
1359
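    /// Starts running the tool uses requested by the model. Tools that ask for
    /// confirmation are queued for the user instead, unless
    /// `always_allow_tool_actions` is enabled in the assistant settings.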
1360 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1361 self.auto_capture_telemetry(cx);
1362 let request = self.to_completion_request(RequestKind::Chat, cx);
1363 let messages = Arc::new(request.messages);
1364 let pending_tool_uses = self
1365 .tool_use
1366 .pending_tool_uses()
1367 .into_iter()
1368 .filter(|tool_use| tool_use.status.is_idle())
1369 .cloned()
1370 .collect::<Vec<_>>();
1371
1372 for tool_use in pending_tool_uses.iter() {
1373 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1374 if tool.needs_confirmation(&tool_use.input, cx)
1375 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1376 {
1377 self.tool_use.confirm_tool_use(
1378 tool_use.id.clone(),
1379 tool_use.ui_text.clone(),
1380 tool_use.input.clone(),
1381 messages.clone(),
1382 tool,
1383 );
1384 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1385 } else {
1386 self.run_tool(
1387 tool_use.id.clone(),
1388 tool_use.ui_text.clone(),
1389 tool_use.input.clone(),
1390 &messages,
1391 tool,
1392 cx,
1393 );
1394 }
1395 }
1396 }
1397
1398 pending_tool_uses
1399 }
1400
1401 pub fn run_tool(
1402 &mut self,
1403 tool_use_id: LanguageModelToolUseId,
1404 ui_text: impl Into<SharedString>,
1405 input: serde_json::Value,
1406 messages: &[LanguageModelRequestMessage],
1407 tool: Arc<dyn Tool>,
1408 cx: &mut Context<Thread>,
1409 ) {
1410 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1411 self.tool_use
1412 .run_pending_tool(tool_use_id, ui_text.into(), task);
1413 }
1414
1415 fn spawn_tool_use(
1416 &mut self,
1417 tool_use_id: LanguageModelToolUseId,
1418 messages: &[LanguageModelRequestMessage],
1419 input: serde_json::Value,
1420 tool: Arc<dyn Tool>,
1421 cx: &mut Context<Thread>,
1422 ) -> Task<()> {
1423 let tool_name: Arc<str> = tool.name().into();
1424
1425 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1426 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1427 } else {
1428 tool.run(
1429 input,
1430 messages,
1431 self.project.clone(),
1432 self.action_log.clone(),
1433 cx,
1434 )
1435 };
1436
1437 // Store the card separately if it exists
1438 if let Some(card) = tool_result.card.clone() {
1439 self.tool_use
1440 .insert_tool_result_card(tool_use_id.clone(), card);
1441 }
1442
1443 cx.spawn({
1444 async move |thread: WeakEntity<Thread>, cx| {
1445 let output = tool_result.output.await;
1446
1447 thread
1448 .update(cx, |thread, cx| {
1449 let pending_tool_use = thread.tool_use.insert_tool_output(
1450 tool_use_id.clone(),
1451 tool_name,
1452 output,
1453 cx,
1454 );
1455 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1456 })
1457 .ok();
1458 }
1459 })
1460 }
1461
1462 fn tool_finished(
1463 &mut self,
1464 tool_use_id: LanguageModelToolUseId,
1465 pending_tool_use: Option<PendingToolUse>,
1466 canceled: bool,
1467 cx: &mut Context<Self>,
1468 ) {
1469 if self.all_tools_finished() {
1470 let model_registry = LanguageModelRegistry::read_global(cx);
1471 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1472 self.attach_tool_results(cx);
1473 if !canceled {
1474 self.send_to_model(model, RequestKind::Chat, cx);
1475 }
1476 }
1477 }
1478
1479 cx.emit(ThreadEvent::ToolFinished {
1480 tool_use_id,
1481 pending_tool_use,
1482 });
1483 }
1484
1485 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1486 // Insert a user message to contain the tool results.
1487 self.insert_user_message(
1488 // TODO: Sending up a user message without any content results in the model sending back
1489 // responses that also don't have any content. We currently don't handle this case well,
1490 // so for now we provide some text to keep the model on track.
1491 "Here are the tool results.",
1492 Vec::new(),
1493 None,
1494 cx,
1495 );
1496 }
1497
    /// Cancels the last pending completion, if any are pending.
1499 ///
1500 /// Returns whether a completion was canceled.
1501 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1502 let canceled = if self.pending_completions.pop().is_some() {
1503 true
1504 } else {
1505 let mut canceled = false;
1506 for pending_tool_use in self.tool_use.cancel_pending() {
1507 canceled = true;
1508 self.tool_finished(
1509 pending_tool_use.id.clone(),
1510 Some(pending_tool_use),
1511 true,
1512 cx,
1513 );
1514 }
1515 canceled
1516 };
1517 self.finalize_pending_checkpoint(cx);
1518 canceled
1519 }
1520
1521 pub fn feedback(&self) -> Option<ThreadFeedback> {
1522 self.feedback
1523 }
1524
1525 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1526 self.message_feedback.get(&message_id).copied()
1527 }
1528
1529 pub fn report_message_feedback(
1530 &mut self,
1531 message_id: MessageId,
1532 feedback: ThreadFeedback,
1533 cx: &mut Context<Self>,
1534 ) -> Task<Result<()>> {
1535 if self.message_feedback.get(&message_id) == Some(&feedback) {
1536 return Task::ready(Ok(()));
1537 }
1538
1539 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1540 let serialized_thread = self.serialize(cx);
1541 let thread_id = self.id().clone();
1542 let client = self.project.read(cx).client();
1543
1544 let enabled_tool_names: Vec<String> = self
1545 .tools()
1546 .read(cx)
1547 .enabled_tools(cx)
1548 .iter()
1549 .map(|tool| tool.name().to_string())
1550 .collect();
1551
1552 self.message_feedback.insert(message_id, feedback);
1553
1554 cx.notify();
1555
1556 let message_content = self
1557 .message(message_id)
1558 .map(|msg| msg.to_string())
1559 .unwrap_or_default();
1560
1561 cx.background_spawn(async move {
1562 let final_project_snapshot = final_project_snapshot.await;
1563 let serialized_thread = serialized_thread.await?;
1564 let thread_data =
1565 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1566
1567 let rating = match feedback {
1568 ThreadFeedback::Positive => "positive",
1569 ThreadFeedback::Negative => "negative",
1570 };
1571 telemetry::event!(
1572 "Assistant Thread Rated",
1573 rating,
1574 thread_id,
1575 enabled_tool_names,
1576 message_id = message_id.0,
1577 message_content,
1578 thread_data,
1579 final_project_snapshot
1580 );
1581 client.telemetry().flush_events();
1582
1583 Ok(())
1584 })
1585 }
1586
1587 pub fn report_feedback(
1588 &mut self,
1589 feedback: ThreadFeedback,
1590 cx: &mut Context<Self>,
1591 ) -> Task<Result<()>> {
1592 let last_assistant_message_id = self
1593 .messages
1594 .iter()
1595 .rev()
1596 .find(|msg| msg.role == Role::Assistant)
1597 .map(|msg| msg.id);
1598
1599 if let Some(message_id) = last_assistant_message_id {
1600 self.report_message_feedback(message_id, feedback, cx)
1601 } else {
1602 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1603 let serialized_thread = self.serialize(cx);
1604 let thread_id = self.id().clone();
1605 let client = self.project.read(cx).client();
1606 self.feedback = Some(feedback);
1607 cx.notify();
1608
1609 cx.background_spawn(async move {
1610 let final_project_snapshot = final_project_snapshot.await;
1611 let serialized_thread = serialized_thread.await?;
1612 let thread_data = serde_json::to_value(serialized_thread)
1613 .unwrap_or_else(|_| serde_json::Value::Null);
1614
1615 let rating = match feedback {
1616 ThreadFeedback::Positive => "positive",
1617 ThreadFeedback::Negative => "negative",
1618 };
1619 telemetry::event!(
1620 "Assistant Thread Rated",
1621 rating,
1622 thread_id,
1623 thread_data,
1624 final_project_snapshot
1625 );
1626 client.telemetry().flush_events();
1627
1628 Ok(())
1629 })
1630 }
1631 }
1632
    /// Creates a snapshot of the current project state, including git information and unsaved buffers.
1634 fn project_snapshot(
1635 project: Entity<Project>,
1636 cx: &mut Context<Self>,
1637 ) -> Task<Arc<ProjectSnapshot>> {
1638 let git_store = project.read(cx).git_store().clone();
1639 let worktree_snapshots: Vec<_> = project
1640 .read(cx)
1641 .visible_worktrees(cx)
1642 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1643 .collect();
1644
1645 cx.spawn(async move |_, cx| {
1646 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1647
1648 let mut unsaved_buffers = Vec::new();
1649 cx.update(|app_cx| {
1650 let buffer_store = project.read(app_cx).buffer_store();
1651 for buffer_handle in buffer_store.read(app_cx).buffers() {
1652 let buffer = buffer_handle.read(app_cx);
1653 if buffer.is_dirty() {
1654 if let Some(file) = buffer.file() {
1655 let path = file.path().to_string_lossy().to_string();
1656 unsaved_buffers.push(path);
1657 }
1658 }
1659 }
1660 })
1661 .ok();
1662
1663 Arc::new(ProjectSnapshot {
1664 worktree_snapshots,
1665 unsaved_buffer_paths: unsaved_buffers,
1666 timestamp: Utc::now(),
1667 })
1668 })
1669 }
1670
1671 fn worktree_snapshot(
1672 worktree: Entity<project::Worktree>,
1673 git_store: Entity<GitStore>,
1674 cx: &App,
1675 ) -> Task<WorktreeSnapshot> {
1676 cx.spawn(async move |cx| {
1677 // Get worktree path and snapshot
1678 let worktree_info = cx.update(|app_cx| {
1679 let worktree = worktree.read(app_cx);
1680 let path = worktree.abs_path().to_string_lossy().to_string();
1681 let snapshot = worktree.snapshot();
1682 (path, snapshot)
1683 });
1684
1685 let Ok((worktree_path, _snapshot)) = worktree_info else {
1686 return WorktreeSnapshot {
1687 worktree_path: String::new(),
1688 git_state: None,
1689 };
1690 };
1691
1692 let git_state = git_store
1693 .update(cx, |git_store, cx| {
1694 git_store
1695 .repositories()
1696 .values()
1697 .find(|repo| {
1698 repo.read(cx)
1699 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1700 .is_some()
1701 })
1702 .cloned()
1703 })
1704 .ok()
1705 .flatten()
1706 .map(|repo| {
1707 repo.update(cx, |repo, _| {
1708 let current_branch =
1709 repo.branch.as_ref().map(|branch| branch.name.to_string());
1710 repo.send_job(None, |state, _| async move {
1711 let RepositoryState::Local { backend, .. } = state else {
1712 return GitState {
1713 remote_url: None,
1714 head_sha: None,
1715 current_branch,
1716 diff: None,
1717 };
1718 };
1719
1720 let remote_url = backend.remote_url("origin");
1721 let head_sha = backend.head_sha();
1722 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1723
1724 GitState {
1725 remote_url,
1726 head_sha,
1727 current_branch,
1728 diff,
1729 }
1730 })
1731 })
1732 });
1733
1734 let git_state = match git_state {
1735 Some(git_state) => match git_state.ok() {
1736 Some(git_state) => git_state.await.ok(),
1737 None => None,
1738 },
1739 None => None,
1740 };
1741
1742 WorktreeSnapshot {
1743 worktree_path,
1744 git_state,
1745 }
1746 })
1747 }
1748
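    /// Renders the thread as Markdown.
    ///
    /// Rough shape of the output (a sketch; tool-use and tool-result sections
    /// appear only for messages that involved tools):
    ///
    /// ```text
    /// # <summary>
    ///
    /// ## User
    ///
    /// <message text>
    ///
    /// ## Assistant
    ///
    /// <message text>
    ///
    /// **Use Tool: <tool name> (<tool use id>)**
    /// ```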
1749 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1750 let mut markdown = Vec::new();
1751
1752 if let Some(summary) = self.summary() {
1753 writeln!(markdown, "# {summary}\n")?;
1754 };
1755
1756 for message in self.messages() {
1757 writeln!(
1758 markdown,
1759 "## {role}\n",
1760 role = match message.role {
1761 Role::User => "User",
1762 Role::Assistant => "Assistant",
1763 Role::System => "System",
1764 }
1765 )?;
1766
1767 if !message.context.is_empty() {
1768 writeln!(markdown, "{}", message.context)?;
1769 }
1770
1771 for segment in &message.segments {
1772 match segment {
1773 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1774 MessageSegment::Thinking(text) => {
1775 writeln!(markdown, "<think>{}</think>\n", text)?
1776 }
1777 }
1778 }
1779
1780 for tool_use in self.tool_uses_for_message(message.id, cx) {
1781 writeln!(
1782 markdown,
1783 "**Use Tool: {} ({})**",
1784 tool_use.name, tool_use.id
1785 )?;
1786 writeln!(markdown, "```json")?;
1787 writeln!(
1788 markdown,
1789 "{}",
1790 serde_json::to_string_pretty(&tool_use.input)?
1791 )?;
1792 writeln!(markdown, "```")?;
1793 }
1794
1795 for tool_result in self.tool_results_for_message(message.id) {
1796 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1797 if tool_result.is_error {
1798 write!(markdown, " (Error)")?;
1799 }
1800
1801 writeln!(markdown, "**\n")?;
1802 writeln!(markdown, "{}", tool_result.content)?;
1803 }
1804 }
1805
1806 Ok(String::from_utf8_lossy(&markdown).to_string())
1807 }
1808
1809 pub fn keep_edits_in_range(
1810 &mut self,
1811 buffer: Entity<language::Buffer>,
1812 buffer_range: Range<language::Anchor>,
1813 cx: &mut Context<Self>,
1814 ) {
1815 self.action_log.update(cx, |action_log, cx| {
1816 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1817 });
1818 }
1819
1820 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1821 self.action_log
1822 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1823 }
1824
1825 pub fn reject_edits_in_ranges(
1826 &mut self,
1827 buffer: Entity<language::Buffer>,
1828 buffer_ranges: Vec<Range<language::Anchor>>,
1829 cx: &mut Context<Self>,
1830 ) -> Task<Result<()>> {
1831 self.action_log.update(cx, |action_log, cx| {
1832 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1833 })
1834 }
1835
1836 pub fn action_log(&self) -> &Entity<ActionLog> {
1837 &self.action_log
1838 }
1839
1840 pub fn project(&self) -> &Entity<Project> {
1841 &self.project
1842 }
1843
1844 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1845 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1846 return;
1847 }
1848
1849 let now = Instant::now();
1850 if let Some(last) = self.last_auto_capture_at {
1851 if now.duration_since(last).as_secs() < 10 {
1852 return;
1853 }
1854 }
1855
1856 self.last_auto_capture_at = Some(now);
1857
1858 let thread_id = self.id().clone();
1859 let github_login = self
1860 .project
1861 .read(cx)
1862 .user_store()
1863 .read(cx)
1864 .current_user()
1865 .map(|user| user.github_login.clone());
1866 let client = self.project.read(cx).client().clone();
1867 let serialize_task = self.serialize(cx);
1868
1869 cx.background_executor()
1870 .spawn(async move {
1871 if let Ok(serialized_thread) = serialize_task.await {
1872 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1873 telemetry::event!(
1874 "Agent Thread Auto-Captured",
1875 thread_id = thread_id.to_string(),
1876 thread_data = thread_data,
1877 auto_capture_reason = "tracked_user",
1878 github_login = github_login
1879 );
1880
1881 client.telemetry().flush_events();
1882 }
1883 }
1884 })
1885 .detach();
1886 }
1887
1888 pub fn cumulative_token_usage(&self) -> TokenUsage {
1889 self.cumulative_token_usage
1890 }
1891
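    /// Computes this thread's cumulative token usage relative to the default
    /// model's context window.
    ///
    /// Worked example (numbers illustrative): with a 200_000-token window and
    /// 170_000 cumulative tokens, 170_000 / 200_000 = 0.85 >= 0.8 (the default
    /// warning threshold), so the ratio is [`TokenUsageRatio::Warning`]; a total
    /// at or above the window itself is [`TokenUsageRatio::Exceeded`].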
1892 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1893 let model_registry = LanguageModelRegistry::read_global(cx);
1894 let Some(model) = model_registry.default_model() else {
1895 return TotalTokenUsage::default();
1896 };
1897
1898 let max = model.model.max_token_count();
1899
1900 if let Some(exceeded_error) = &self.exceeded_window_error {
1901 if model.model.id() == exceeded_error.model_id {
1902 return TotalTokenUsage {
1903 total: exceeded_error.token_count,
1904 max,
1905 ratio: TokenUsageRatio::Exceeded,
1906 };
1907 }
1908 }
1909
1910 #[cfg(debug_assertions)]
1911 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1912 .unwrap_or("0.8".to_string())
1913 .parse()
1914 .unwrap();
1915 #[cfg(not(debug_assertions))]
1916 let warning_threshold: f32 = 0.8;
1917
1918 let total = self.cumulative_token_usage.total_tokens() as usize;
1919
1920 let ratio = if total >= max {
1921 TokenUsageRatio::Exceeded
1922 } else if total as f32 / max as f32 >= warning_threshold {
1923 TokenUsageRatio::Warning
1924 } else {
1925 TokenUsageRatio::Normal
1926 };
1927
1928 TotalTokenUsage { total, max, ratio }
1929 }
1930
1931 pub fn deny_tool_use(
1932 &mut self,
1933 tool_use_id: LanguageModelToolUseId,
1934 tool_name: Arc<str>,
1935 cx: &mut Context<Self>,
1936 ) {
1937 let err = Err(anyhow::anyhow!(
1938 "Permission to run tool action denied by user"
1939 ));
1940
1941 self.tool_use
1942 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1943 self.tool_finished(tool_use_id.clone(), None, true, cx);
1944 }
1945}
1946
1947#[derive(Debug, Clone, Error)]
1948pub enum ThreadError {
1949 #[error("Payment required")]
1950 PaymentRequired,
1951 #[error("Max monthly spend reached")]
1952 MaxMonthlySpendReached,
1953 #[error("Model request limit reached")]
1954 ModelRequestLimitReached { plan: Plan },
1955 #[error("Message {header}: {message}")]
1956 Message {
1957 header: SharedString,
1958 message: SharedString,
1959 },
1960}
1961
1962#[derive(Debug, Clone)]
1963pub enum ThreadEvent {
1964 ShowError(ThreadError),
1965 StreamedCompletion,
1966 StreamedAssistantText(MessageId, String),
1967 StreamedAssistantThinking(MessageId, String),
1968 Stopped(Result<StopReason, Arc<anyhow::Error>>),
1969 MessageAdded(MessageId),
1970 MessageEdited(MessageId),
1971 MessageDeleted(MessageId),
1972 SummaryGenerated,
1973 SummaryChanged,
1974 UsePendingTools {
1975 tool_uses: Vec<PendingToolUse>,
1976 },
1977 ToolFinished {
1978 #[allow(unused)]
1979 tool_use_id: LanguageModelToolUseId,
1980 /// The pending tool use that corresponds to this tool.
1981 pending_tool_use: Option<PendingToolUse>,
1982 },
1983 CheckpointChanged,
1984 ToolConfirmationNeeded,
1985}
1986
1987impl EventEmitter<ThreadEvent> for Thread {}
1988
1989struct PendingCompletion {
1990 id: usize,
1991 _task: Task<()>,
1992}
1993
1994#[cfg(test)]
1995mod tests {
1996 use super::*;
1997 use crate::{ThreadStore, context_store::ContextStore, thread_store};
1998 use assistant_settings::AssistantSettings;
1999 use context_server::ContextServerSettings;
2000 use editor::EditorSettings;
2001 use gpui::TestAppContext;
2002 use project::{FakeFs, Project};
2003 use prompt_store::PromptBuilder;
2004 use serde_json::json;
2005 use settings::{Settings, SettingsStore};
2006 use std::sync::Arc;
2007 use theme::ThemeSettings;
2008 use util::path;
2009 use workspace::Workspace;
2010
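    // A minimal sanity check of the marker-filtering logic, assuming the
    // semantics that empty or marker-only segments should not be displayed.
    #[test]
    fn test_message_segment_should_display() {
        assert!(MessageSegment::Text("Let me look at that file.".into()).should_display());
        assert!(!MessageSegment::Text(String::new()).should_display());
        assert!(!MessageSegment::Text(USING_TOOL_MARKER.to_string()).should_display());
        assert!(!MessageSegment::Thinking(format!("  {USING_TOOL_MARKER}  ")).should_display());
    }
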
2011 #[gpui::test]
2012 async fn test_message_with_context(cx: &mut TestAppContext) {
2013 init_test_settings(cx);
2014
2015 let project = create_test_project(
2016 cx,
2017 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2018 )
2019 .await;
2020
2021 let (_workspace, _thread_store, thread, context_store) =
2022 setup_test_environment(cx, project.clone()).await;
2023
2024 add_file_to_context(&project, &context_store, "test/code.rs", cx)
2025 .await
2026 .unwrap();
2027
2028 let context =
2029 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2030
2031 // Insert user message with context
2032 let message_id = thread.update(cx, |thread, cx| {
2033 thread.insert_user_message("Please explain this code", vec![context], None, cx)
2034 });
2035
2036 // Check content and context in message object
2037 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2038
2039        // The path embedded in the expected context uses the platform's separator, so build it per platform
2040 #[cfg(windows)]
2041 let path_part = r"test\code.rs";
2042 #[cfg(not(windows))]
2043 let path_part = "test/code.rs";
2044
2045 let expected_context = format!(
2046 r#"
2047<context>
2048The following items were attached by the user. You don't need to use other tools to read them.
2049
2050<files>
2051```rs {path_part}
2052fn main() {{
2053 println!("Hello, world!");
2054}}
2055```
2056</files>
2057</context>
2058"#
2059 );
2060
2061 assert_eq!(message.role, Role::User);
2062 assert_eq!(message.segments.len(), 1);
2063 assert_eq!(
2064 message.segments[0],
2065 MessageSegment::Text("Please explain this code".to_string())
2066 );
2067 assert_eq!(message.context, expected_context);
2068
2069 // Check message in request
2070 let request = thread.read_with(cx, |thread, cx| {
2071 thread.to_completion_request(RequestKind::Chat, cx)
2072 });
2073
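        // The request begins with the system prompt, so the user message we
        // inserted appears at index 1.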
2074 assert_eq!(request.messages.len(), 2);
2075 let expected_full_message = format!("{}Please explain this code", expected_context);
2076 assert_eq!(request.messages[1].string_contents(), expected_full_message);
2077 }
2078
2079 #[gpui::test]
2080 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2081 init_test_settings(cx);
2082
2083 let project = create_test_project(
2084 cx,
2085 json!({
2086 "file1.rs": "fn function1() {}\n",
2087 "file2.rs": "fn function2() {}\n",
2088 "file3.rs": "fn function3() {}\n",
2089 }),
2090 )
2091 .await;
2092
2093 let (_, _thread_store, thread, context_store) =
2094 setup_test_environment(cx, project.clone()).await;
2095
2096        // Add each file to the context individually
2097 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2098 .await
2099 .unwrap();
2100 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2101 .await
2102 .unwrap();
2103 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2104 .await
2105 .unwrap();
2106
2107 // Get the context objects
2108 let contexts = context_store.update(cx, |store, _| store.context().clone());
2109 assert_eq!(contexts.len(), 3);
2110
2111 // First message with context 1
2112 let message1_id = thread.update(cx, |thread, cx| {
2113 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2114 });
2115
2116 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2117 let message2_id = thread.update(cx, |thread, cx| {
2118 thread.insert_user_message(
2119 "Message 2",
2120 vec![contexts[0].clone(), contexts[1].clone()],
2121 None,
2122 cx,
2123 )
2124 });
2125
2126 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2127 let message3_id = thread.update(cx, |thread, cx| {
2128 thread.insert_user_message(
2129 "Message 3",
2130 vec![
2131 contexts[0].clone(),
2132 contexts[1].clone(),
2133 contexts[2].clone(),
2134 ],
2135 None,
2136 cx,
2137 )
2138 });
2139
2140 // Check what contexts are included in each message
2141 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2142 (
2143 thread.message(message1_id).unwrap().clone(),
2144 thread.message(message2_id).unwrap().clone(),
2145 thread.message(message3_id).unwrap().clone(),
2146 )
2147 });
2148
2149 // First message should include context 1
2150 assert!(message1.context.contains("file1.rs"));
2151
2152 // Second message should include only context 2 (not 1)
2153 assert!(!message2.context.contains("file1.rs"));
2154 assert!(message2.context.contains("file2.rs"));
2155
2156 // Third message should include only context 3 (not 1 or 2)
2157 assert!(!message3.context.contains("file1.rs"));
2158 assert!(!message3.context.contains("file2.rs"));
2159 assert!(message3.context.contains("file3.rs"));
2160
2161 // Check entire request to make sure all contexts are properly included
2162 let request = thread.read_with(cx, |thread, cx| {
2163 thread.to_completion_request(RequestKind::Chat, cx)
2164 });
2165
2166        // The request should contain all three user messages, after the leading system message
2167 assert_eq!(request.messages.len(), 4);
2168
2169 // Check that the contexts are properly formatted in each message
2170 assert!(request.messages[1].string_contents().contains("file1.rs"));
2171 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2172 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2173
2174 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2175 assert!(request.messages[2].string_contents().contains("file2.rs"));
2176 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2177
2178 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2179 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2180 assert!(request.messages[3].string_contents().contains("file3.rs"));
2181 }
2182
2183 #[gpui::test]
2184 async fn test_message_without_files(cx: &mut TestAppContext) {
2185 init_test_settings(cx);
2186
2187 let project = create_test_project(
2188 cx,
2189 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2190 )
2191 .await;
2192
2193 let (_, _thread_store, thread, _context_store) =
2194 setup_test_environment(cx, project.clone()).await;
2195
2196 // Insert user message without any context (empty context vector)
2197 let message_id = thread.update(cx, |thread, cx| {
2198 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2199 });
2200
2201 // Check content and context in message object
2202 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2203
2204 // Context should be empty when no files are included
2205 assert_eq!(message.role, Role::User);
2206 assert_eq!(message.segments.len(), 1);
2207 assert_eq!(
2208 message.segments[0],
2209 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2210 );
2211 assert_eq!(message.context, "");
2212
2213 // Check message in request
2214 let request = thread.read_with(cx, |thread, cx| {
2215 thread.to_completion_request(RequestKind::Chat, cx)
2216 });
2217
2218 assert_eq!(request.messages.len(), 2);
2219 assert_eq!(
2220 request.messages[1].string_contents(),
2221 "What is the best way to learn Rust?"
2222 );
2223
2224 // Add second message, also without context
2225 let message2_id = thread.update(cx, |thread, cx| {
2226 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2227 });
2228
2229 let message2 =
2230 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2231 assert_eq!(message2.context, "");
2232
2233 // Check that both messages appear in the request
2234 let request = thread.read_with(cx, |thread, cx| {
2235 thread.to_completion_request(RequestKind::Chat, cx)
2236 });
2237
2238 assert_eq!(request.messages.len(), 3);
2239 assert_eq!(
2240 request.messages[1].string_contents(),
2241 "What is the best way to learn Rust?"
2242 );
2243 assert_eq!(
2244 request.messages[2].string_contents(),
2245 "Are there any good books?"
2246 );
2247 }
2248
2249 #[gpui::test]
2250 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2251 init_test_settings(cx);
2252
2253 let project = create_test_project(
2254 cx,
2255 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2256 )
2257 .await;
2258
2259 let (_workspace, _thread_store, thread, context_store) =
2260 setup_test_environment(cx, project.clone()).await;
2261
2262 // Open buffer and add it to context
2263 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2264 .await
2265 .unwrap();
2266
2267 let context =
2268 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2269
2270 // Insert user message with the buffer as context
2271 thread.update(cx, |thread, cx| {
2272 thread.insert_user_message("Explain this code", vec![context], None, cx)
2273 });
2274
2275 // Create a request and check that it doesn't have a stale buffer warning yet
2276 let initial_request = thread.read_with(cx, |thread, cx| {
2277 thread.to_completion_request(RequestKind::Chat, cx)
2278 });
2279
2280 // Make sure we don't have a stale file warning yet
2281 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2282 msg.string_contents()
2283 .contains("These files changed since last read:")
2284 });
2285 assert!(
2286 !has_stale_warning,
2287 "Should not have stale buffer warning before buffer is modified"
2288 );
2289
2290 // Modify the buffer
2291 buffer.update(cx, |buffer, cx| {
2292            // Any edit is enough to mark the file as changed since it was read into context
2293 buffer.edit(
2294 [(1..1, "\n println!(\"Added a new line\");\n")],
2295 None,
2296 cx,
2297 );
2298 });
2299
2300 // Insert another user message without context
2301 thread.update(cx, |thread, cx| {
2302 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2303 });
2304
2305 // Create a new request and check for the stale buffer warning
2306 let new_request = thread.read_with(cx, |thread, cx| {
2307 thread.to_completion_request(RequestKind::Chat, cx)
2308 });
2309
2310 // We should have a stale file warning as the last message
2311 let last_message = new_request
2312 .messages
2313 .last()
2314 .expect("Request should have messages");
2315
2316 // The last message should be the stale buffer notification
2317 assert_eq!(last_message.role, Role::User);
2318
2319 // Check the exact content of the message
2320 let expected_content = "These files changed since last read:\n- code.rs\n";
2321 assert_eq!(
2322 last_message.string_contents(),
2323 expected_content,
2324 "Last message should be exactly the stale buffer notification"
2325 );
2326 }
2327
2328 fn init_test_settings(cx: &mut TestAppContext) {
2329 cx.update(|cx| {
2330 let settings_store = SettingsStore::test(cx);
2331 cx.set_global(settings_store);
2332 language::init(cx);
2333 Project::init_settings(cx);
2334 AssistantSettings::register(cx);
2335 thread_store::init(cx);
2336 workspace::init_settings(cx);
2337 ThemeSettings::register(cx);
2338 ContextServerSettings::register(cx);
2339 EditorSettings::register(cx);
2340 });
2341 }
2342
2343 // Helper to create a test project with test files
2344 async fn create_test_project(
2345 cx: &mut TestAppContext,
2346 files: serde_json::Value,
2347 ) -> Entity<Project> {
2348 let fs = FakeFs::new(cx.executor());
2349 fs.insert_tree(path!("/test"), files).await;
2350 Project::test(fs, [path!("/test").as_ref()], cx).await
2351 }
2352
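    // Helper to build the common test fixture: a test workspace, a thread store,
    // a fresh thread, and an empty context store for the given project.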
2353 async fn setup_test_environment(
2354 cx: &mut TestAppContext,
2355 project: Entity<Project>,
2356 ) -> (
2357 Entity<Workspace>,
2358 Entity<ThreadStore>,
2359 Entity<Thread>,
2360 Entity<ContextStore>,
2361 ) {
2362 let (workspace, cx) =
2363 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2364
2365 let thread_store = cx
2366 .update(|_, cx| {
2367 ThreadStore::load(
2368 project.clone(),
2369 cx.new(|_| ToolWorkingSet::default()),
2370 Arc::new(PromptBuilder::new(None).unwrap()),
2371 cx,
2372 )
2373 })
2374 .await;
2375
2376 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2377 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2378
2379 (workspace, thread_store, thread, context_store)
2380 }
2381
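    // Helper to open `path` as a buffer in the project and add it to the context
    // store, returning the buffer so tests can modify it afterwards.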
2382 async fn add_file_to_context(
2383 project: &Entity<Project>,
2384 context_store: &Entity<ContextStore>,
2385 path: &str,
2386 cx: &mut TestAppContext,
2387 ) -> Result<Entity<language::Buffer>> {
2388 let buffer_path = project
2389 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2390 .unwrap();
2391
2392 let buffer = project
2393 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2394 .await
2395 .unwrap();
2396
2397 context_store
2398 .update(cx, |store, cx| {
2399 store.add_file_from_buffer(buffer.clone(), cx)
2400 })
2401 .await?;
2402
2403 Ok(buffer)
2404 }
2405}