1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
23};
24use project::Project;
25use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
26use prompt_store::PromptBuilder;
27use proto::Plan;
28use schemars::JsonSchema;
29use serde::{Deserialize, Serialize};
30use settings::Settings;
31use thiserror::Error;
32use util::{ResultExt as _, TryFutureExt as _, post_inc};
33use uuid::Uuid;
34
35use crate::context::{AssistantContext, ContextId, format_context_as_string};
36use crate::thread_store::{
37 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
38 SerializedToolUse, SharedProjectContext,
39};
40use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
41
42#[derive(Debug, Clone, Copy)]
43pub enum RequestKind {
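    /// Used for regular chat requests, which may include tool uses and results.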
44 Chat,
45 /// Used when summarizing a thread.
46 Summarize,
47}
48
49#[derive(
50 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
51)]
52pub struct ThreadId(Arc<str>);
53
54impl ThreadId {
55 pub fn new() -> Self {
56 Self(Uuid::new_v4().to_string().into())
57 }
58}
59
60impl std::fmt::Display for ThreadId {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(f, "{}", self.0)
63 }
64}
65
66impl From<&str> for ThreadId {
67 fn from(value: &str) -> Self {
68 Self(value.into())
69 }
70}
71
72#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
73pub struct MessageId(pub(crate) usize);
74
75impl MessageId {
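    /// Returns the current ID and advances the counter to the next one.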
76 fn post_inc(&mut self) -> Self {
77 Self(post_inc(&mut self.0))
78 }
79}
80
81/// A message in a [`Thread`].
82#[derive(Debug, Clone)]
83pub struct Message {
84 pub id: MessageId,
85 pub role: Role,
86 pub segments: Vec<MessageSegment>,
87 pub context: String,
88}
89
90impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    /// The model sometimes runs a tool without producing any text, or emits only a marker ([`USING_TOOL_MARKER`]).
93 pub fn should_display_content(&self) -> bool {
94 self.segments.iter().all(|segment| segment.should_display())
95 }
96
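    /// Appends streamed thinking text, coalescing with the previous segment when it is
    /// also a thinking segment (`push_text` does the same for text segments), so a
    /// streamed response doesn't produce one segment per chunk.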
97 pub fn push_thinking(&mut self, text: &str) {
98 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
99 segment.push_str(text);
100 } else {
101 self.segments
102 .push(MessageSegment::Thinking(text.to_string()));
103 }
104 }
105
106 pub fn push_text(&mut self, text: &str) {
107 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
108 segment.push_str(text);
109 } else {
110 self.segments.push(MessageSegment::Text(text.to_string()));
111 }
112 }
113
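    /// Renders this message as plain text: the attached context first, then each segment
    /// in order, with thinking segments wrapped in `<think>` tags. A minimal sketch of the
    /// expected shape (marked `ignore`, so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let mut message = Message {
    ///     id: MessageId(0),
    ///     role: Role::User,
    ///     segments: Vec::new(),
    ///     context: String::new(),
    /// };
    /// message.push_text("Hello");
    /// message.push_thinking("pondering");
    /// assert_eq!(message.to_string(), "Hello<think>pondering</think>");
    /// ```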
114 pub fn to_string(&self) -> String {
115 let mut result = String::new();
116
117 if !self.context.is_empty() {
118 result.push_str(&self.context);
119 }
120
121 for segment in &self.segments {
122 match segment {
123 MessageSegment::Text(text) => result.push_str(text),
124 MessageSegment::Thinking(text) => {
125 result.push_str("<think>");
126 result.push_str(text);
127 result.push_str("</think>");
128 }
129 }
130 }
131
132 result
133 }
134}
135
136#[derive(Debug, Clone, PartialEq, Eq)]
137pub enum MessageSegment {
138 Text(String),
139 Thinking(String),
140}
141
142impl MessageSegment {
143 pub fn text_mut(&mut self) -> &mut String {
144 match self {
145 Self::Text(text) => text,
146 Self::Thinking(text) => text,
147 }
148 }
149
150 pub fn should_display(&self) -> bool {
151 // We add USING_TOOL_MARKER when making a request that includes tool uses
152 // without non-whitespace text around them, and this can cause the model
153 // to mimic the pattern, so we consider those segments not displayable.
154 match self {
            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::Thinking(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
157 }
158 }
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct ProjectSnapshot {
163 pub worktree_snapshots: Vec<WorktreeSnapshot>,
164 pub unsaved_buffer_paths: Vec<String>,
165 pub timestamp: DateTime<Utc>,
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct WorktreeSnapshot {
170 pub worktree_path: String,
171 pub git_state: Option<GitState>,
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct GitState {
176 pub remote_url: Option<String>,
177 pub head_sha: Option<String>,
178 pub current_branch: Option<String>,
179 pub diff: Option<String>,
180}
181
182#[derive(Clone)]
183pub struct ThreadCheckpoint {
184 message_id: MessageId,
185 git_checkpoint: GitStoreCheckpoint,
186}
187
188#[derive(Copy, Clone, Debug, PartialEq, Eq)]
189pub enum ThreadFeedback {
190 Positive,
191 Negative,
192}
193
194pub enum LastRestoreCheckpoint {
195 Pending {
196 message_id: MessageId,
197 },
198 Error {
199 message_id: MessageId,
200 error: String,
201 },
202}
203
204impl LastRestoreCheckpoint {
205 pub fn message_id(&self) -> MessageId {
206 match self {
207 LastRestoreCheckpoint::Pending { message_id } => *message_id,
208 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
209 }
210 }
211}
212
213#[derive(Clone, Debug, Default, Serialize, Deserialize)]
214pub enum DetailedSummaryState {
215 #[default]
216 NotGenerated,
217 Generating {
218 message_id: MessageId,
219 },
220 Generated {
221 text: SharedString,
222 message_id: MessageId,
223 },
224}
225
226#[derive(Default)]
227pub struct TotalTokenUsage {
228 pub total: usize,
229 pub max: usize,
230 pub ratio: TokenUsageRatio,
231}
232
233#[derive(Debug, Default, PartialEq, Eq)]
234pub enum TokenUsageRatio {
235 #[default]
236 Normal,
237 Warning,
238 Exceeded,
239}
240
241/// A thread of conversation with the LLM.
242pub struct Thread {
243 id: ThreadId,
244 updated_at: DateTime<Utc>,
245 summary: Option<SharedString>,
246 pending_summary: Task<Option<()>>,
247 detailed_summary_state: DetailedSummaryState,
248 messages: Vec<Message>,
249 next_message_id: MessageId,
250 context: BTreeMap<ContextId, AssistantContext>,
251 context_by_message: HashMap<MessageId, Vec<ContextId>>,
252 project_context: SharedProjectContext,
253 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
254 completion_count: usize,
255 pending_completions: Vec<PendingCompletion>,
256 project: Entity<Project>,
257 prompt_builder: Arc<PromptBuilder>,
258 tools: Entity<ToolWorkingSet>,
259 tool_use: ToolUseState,
260 action_log: Entity<ActionLog>,
261 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
262 pending_checkpoint: Option<ThreadCheckpoint>,
263 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
264 cumulative_token_usage: TokenUsage,
265 exceeded_window_error: Option<ExceededWindowError>,
266 feedback: Option<ThreadFeedback>,
267 message_feedback: HashMap<MessageId, ThreadFeedback>,
268 last_auto_capture_at: Option<Instant>,
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct ExceededWindowError {
    /// The model in use when the last message exceeded the context window.
    model_id: LanguageModelId,
    /// The token count, including the last message.
    token_count: usize,
277}
278
279impl Thread {
280 pub fn new(
281 project: Entity<Project>,
282 tools: Entity<ToolWorkingSet>,
283 prompt_builder: Arc<PromptBuilder>,
284 system_prompt: SharedProjectContext,
285 cx: &mut Context<Self>,
286 ) -> Self {
287 Self {
288 id: ThreadId::new(),
289 updated_at: Utc::now(),
290 summary: None,
291 pending_summary: Task::ready(None),
292 detailed_summary_state: DetailedSummaryState::NotGenerated,
293 messages: Vec::new(),
294 next_message_id: MessageId(0),
295 context: BTreeMap::default(),
296 context_by_message: HashMap::default(),
297 project_context: system_prompt,
298 checkpoints_by_message: HashMap::default(),
299 completion_count: 0,
300 pending_completions: Vec::new(),
301 project: project.clone(),
302 prompt_builder,
303 tools: tools.clone(),
304 last_restore_checkpoint: None,
305 pending_checkpoint: None,
306 tool_use: ToolUseState::new(tools.clone()),
307 action_log: cx.new(|_| ActionLog::new(project.clone())),
308 initial_project_snapshot: {
309 let project_snapshot = Self::project_snapshot(project, cx);
310 cx.foreground_executor()
311 .spawn(async move { Some(project_snapshot.await) })
312 .shared()
313 },
314 cumulative_token_usage: TokenUsage::default(),
315 exceeded_window_error: None,
316 feedback: None,
317 message_feedback: HashMap::default(),
318 last_auto_capture_at: None,
319 }
320 }
321
322 pub fn deserialize(
323 id: ThreadId,
324 serialized: SerializedThread,
325 project: Entity<Project>,
326 tools: Entity<ToolWorkingSet>,
327 prompt_builder: Arc<PromptBuilder>,
328 project_context: SharedProjectContext,
329 cx: &mut Context<Self>,
330 ) -> Self {
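        // Continue numbering after the highest serialized message ID so restored
        // messages and newly inserted ones never collide.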
331 let next_message_id = MessageId(
332 serialized
333 .messages
334 .last()
335 .map(|message| message.id.0 + 1)
336 .unwrap_or(0),
337 );
338 let tool_use =
339 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
340
341 Self {
342 id,
343 updated_at: serialized.updated_at,
344 summary: Some(serialized.summary),
345 pending_summary: Task::ready(None),
346 detailed_summary_state: serialized.detailed_summary_state,
347 messages: serialized
348 .messages
349 .into_iter()
350 .map(|message| Message {
351 id: message.id,
352 role: message.role,
353 segments: message
354 .segments
355 .into_iter()
356 .map(|segment| match segment {
357 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
358 SerializedMessageSegment::Thinking { text } => {
359 MessageSegment::Thinking(text)
360 }
361 })
362 .collect(),
363 context: message.context,
364 })
365 .collect(),
366 next_message_id,
367 context: BTreeMap::default(),
368 context_by_message: HashMap::default(),
369 project_context,
370 checkpoints_by_message: HashMap::default(),
371 completion_count: 0,
372 pending_completions: Vec::new(),
373 last_restore_checkpoint: None,
374 pending_checkpoint: None,
375 project: project.clone(),
376 prompt_builder,
377 tools,
378 tool_use,
379 action_log: cx.new(|_| ActionLog::new(project)),
380 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
381 cumulative_token_usage: serialized.cumulative_token_usage,
382 exceeded_window_error: None,
383 feedback: None,
384 message_feedback: HashMap::default(),
385 last_auto_capture_at: None,
386 }
387 }
388
389 pub fn id(&self) -> &ThreadId {
390 &self.id
391 }
392
393 pub fn is_empty(&self) -> bool {
394 self.messages.is_empty()
395 }
396
397 pub fn updated_at(&self) -> DateTime<Utc> {
398 self.updated_at
399 }
400
401 pub fn touch_updated_at(&mut self) {
402 self.updated_at = Utc::now();
403 }
404
405 pub fn summary(&self) -> Option<SharedString> {
406 self.summary.clone()
407 }
408
409 pub fn project_context(&self) -> SharedProjectContext {
410 self.project_context.clone()
411 }
412
413 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
414
415 pub fn summary_or_default(&self) -> SharedString {
416 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
417 }
418
419 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
420 let Some(current_summary) = &self.summary else {
421 // Don't allow setting summary until generated
422 return;
423 };
424
425 let mut new_summary = new_summary.into();
426
427 if new_summary.is_empty() {
428 new_summary = Self::DEFAULT_SUMMARY;
429 }
430
431 if current_summary != &new_summary {
432 self.summary = Some(new_summary);
433 cx.emit(ThreadEvent::SummaryChanged);
434 }
435 }
436
437 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
438 self.latest_detailed_summary()
439 .unwrap_or_else(|| self.text().into())
440 }
441
442 fn latest_detailed_summary(&self) -> Option<SharedString> {
443 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
444 Some(text.clone())
445 } else {
446 None
447 }
448 }
449
450 pub fn message(&self, id: MessageId) -> Option<&Message> {
451 self.messages.iter().find(|message| message.id == id)
452 }
453
454 pub fn messages(&self) -> impl Iterator<Item = &Message> {
455 self.messages.iter()
456 }
457
458 pub fn is_generating(&self) -> bool {
459 !self.pending_completions.is_empty() || !self.all_tools_finished()
460 }
461
462 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
463 &self.tools
464 }
465
466 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
467 self.tool_use
468 .pending_tool_uses()
469 .into_iter()
470 .find(|tool_use| &tool_use.id == id)
471 }
472
473 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
474 self.tool_use
475 .pending_tool_uses()
476 .into_iter()
477 .filter(|tool_use| tool_use.status.needs_confirmation())
478 }
479
480 pub fn has_pending_tool_uses(&self) -> bool {
481 !self.tool_use.pending_tool_uses().is_empty()
482 }
483
484 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
485 self.checkpoints_by_message.get(&id).cloned()
486 }
487
488 pub fn restore_checkpoint(
489 &mut self,
490 checkpoint: ThreadCheckpoint,
491 cx: &mut Context<Self>,
492 ) -> Task<Result<()>> {
493 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
494 message_id: checkpoint.message_id,
495 });
496 cx.emit(ThreadEvent::CheckpointChanged);
497 cx.notify();
498
499 let git_store = self.project().read(cx).git_store().clone();
500 let restore = git_store.update(cx, |git_store, cx| {
501 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
502 });
503
504 cx.spawn(async move |this, cx| {
505 let result = restore.await;
506 this.update(cx, |this, cx| {
507 if let Err(err) = result.as_ref() {
508 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
509 message_id: checkpoint.message_id,
510 error: err.to_string(),
511 });
512 } else {
513 this.truncate(checkpoint.message_id, cx);
514 this.last_restore_checkpoint = None;
515 }
516 this.pending_checkpoint = None;
517 cx.emit(ThreadEvent::CheckpointChanged);
518 cx.notify();
519 })?;
520 result
521 })
522 }
523
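    /// Called once generation and tool use have settled: keeps the checkpoint taken at the
    /// start of the request only if the working tree actually changed while the model was
    /// responding; otherwise the temporary git checkpoints are deleted.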
524 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
525 let pending_checkpoint = if self.is_generating() {
526 return;
527 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
528 checkpoint
529 } else {
530 return;
531 };
532
533 let git_store = self.project.read(cx).git_store().clone();
534 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
535 cx.spawn(async move |this, cx| match final_checkpoint.await {
536 Ok(final_checkpoint) => {
537 let equal = git_store
538 .update(cx, |store, cx| {
539 store.compare_checkpoints(
540 pending_checkpoint.git_checkpoint.clone(),
541 final_checkpoint.clone(),
542 cx,
543 )
544 })?
545 .await
546 .unwrap_or(false);
547
548 if equal {
549 git_store
550 .update(cx, |store, cx| {
551 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
552 })?
553 .detach();
554 } else {
555 this.update(cx, |this, cx| {
556 this.insert_checkpoint(pending_checkpoint, cx)
557 })?;
558 }
559
560 git_store
561 .update(cx, |store, cx| {
562 store.delete_checkpoint(final_checkpoint, cx)
563 })?
564 .detach();
565
566 Ok(())
567 }
568 Err(_) => this.update(cx, |this, cx| {
569 this.insert_checkpoint(pending_checkpoint, cx)
570 }),
571 })
572 .detach();
573 }
574
575 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
576 self.checkpoints_by_message
577 .insert(checkpoint.message_id, checkpoint);
578 cx.emit(ThreadEvent::CheckpointChanged);
579 cx.notify();
580 }
581
582 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
583 self.last_restore_checkpoint.as_ref()
584 }
585
586 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
587 let Some(message_ix) = self
588 .messages
589 .iter()
590 .rposition(|message| message.id == message_id)
591 else {
592 return;
593 };
594 for deleted_message in self.messages.drain(message_ix..) {
595 self.context_by_message.remove(&deleted_message.id);
596 self.checkpoints_by_message.remove(&deleted_message.id);
597 }
598 cx.notify();
599 }
600
601 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
602 self.context_by_message
603 .get(&id)
604 .into_iter()
605 .flat_map(|context| {
606 context
607 .iter()
608 .filter_map(|context_id| self.context.get(&context_id))
609 })
610 }
611
612 /// Returns whether all of the tool uses have finished running.
613 pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors,
        // then we've finished running all of the pending tools.
616 self.tool_use
617 .pending_tool_uses()
618 .iter()
619 .all(|tool_use| tool_use.status.is_error())
620 }
621
622 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
623 self.tool_use.tool_uses_for_message(id, cx)
624 }
625
626 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
627 self.tool_use.tool_results_for_message(id)
628 }
629
630 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
631 self.tool_use.tool_result(id)
632 }
633
634 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
635 self.tool_use.message_has_tool_results(message_id)
636 }
637
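    /// Inserts a user message along with any newly attached context, optionally recording
    /// a git checkpoint to restore to later. A minimal call sketch (marked `ignore`;
    /// `thread`, `context`, and `cx` are assumed to come from the surrounding app state,
    /// as in the tests at the bottom of this file):
    ///
    /// ```ignore
    /// let message_id = thread.update(cx, |thread, cx| {
    ///     thread.insert_user_message("Please explain this code", vec![context], None, cx)
    /// });
    /// ```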
638 pub fn insert_user_message(
639 &mut self,
640 text: impl Into<String>,
641 context: Vec<AssistantContext>,
642 git_checkpoint: Option<GitStoreCheckpoint>,
643 cx: &mut Context<Self>,
644 ) -> MessageId {
645 let text = text.into();
646
647 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
648
649 // Filter out contexts that have already been included in previous messages
650 let new_context: Vec<_> = context
651 .into_iter()
652 .filter(|ctx| !self.context.contains_key(&ctx.id()))
653 .collect();
654
655 if !new_context.is_empty() {
656 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
657 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
658 message.context = context_string;
659 }
660 }
661
662 self.action_log.update(cx, |log, cx| {
663 // Track all buffers added as context
664 for ctx in &new_context {
665 match ctx {
666 AssistantContext::File(file_ctx) => {
667 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
668 }
669 AssistantContext::Directory(dir_ctx) => {
670 for context_buffer in &dir_ctx.context_buffers {
671 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
672 }
673 }
674 AssistantContext::Symbol(symbol_ctx) => {
675 log.buffer_added_as_context(
676 symbol_ctx.context_symbol.buffer.clone(),
677 cx,
678 );
679 }
680 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
681 }
682 }
683 });
684 }
685
686 let context_ids = new_context
687 .iter()
688 .map(|context| context.id())
689 .collect::<Vec<_>>();
690 self.context.extend(
691 new_context
692 .into_iter()
693 .map(|context| (context.id(), context)),
694 );
695 self.context_by_message.insert(message_id, context_ids);
696
697 if let Some(git_checkpoint) = git_checkpoint {
698 self.pending_checkpoint = Some(ThreadCheckpoint {
699 message_id,
700 git_checkpoint,
701 });
702 }
703
704 self.auto_capture_telemetry(cx);
705
706 message_id
707 }
708
709 pub fn insert_message(
710 &mut self,
711 role: Role,
712 segments: Vec<MessageSegment>,
713 cx: &mut Context<Self>,
714 ) -> MessageId {
715 let id = self.next_message_id.post_inc();
716 self.messages.push(Message {
717 id,
718 role,
719 segments,
720 context: String::new(),
721 });
722 self.touch_updated_at();
723 cx.emit(ThreadEvent::MessageAdded(id));
724 id
725 }
726
727 pub fn edit_message(
728 &mut self,
729 id: MessageId,
730 new_role: Role,
731 new_segments: Vec<MessageSegment>,
732 cx: &mut Context<Self>,
733 ) -> bool {
734 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
735 return false;
736 };
737 message.role = new_role;
738 message.segments = new_segments;
739 self.touch_updated_at();
740 cx.emit(ThreadEvent::MessageEdited(id));
741 true
742 }
743
744 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
745 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
746 return false;
747 };
748 self.messages.remove(index);
749 self.context_by_message.remove(&id);
750 self.touch_updated_at();
751 cx.emit(ThreadEvent::MessageDeleted(id));
752 true
753 }
754
755 /// Returns the representation of this [`Thread`] in a textual form.
756 ///
757 /// This is the representation we use when attaching a thread as context to another thread.
758 pub fn text(&self) -> String {
759 let mut text = String::new();
760
761 for message in &self.messages {
762 text.push_str(match message.role {
763 language_model::Role::User => "User:",
764 language_model::Role::Assistant => "Assistant:",
765 language_model::Role::System => "System:",
766 });
767 text.push('\n');
768
769 for segment in &message.segments {
770 match segment {
771 MessageSegment::Text(content) => text.push_str(content),
772 MessageSegment::Thinking(content) => {
773 text.push_str(&format!("<think>{}</think>", content))
774 }
775 }
776 }
777 text.push('\n');
778 }
779
780 text
781 }
782
783 /// Serializes this thread into a format for storage or telemetry.
784 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
785 let initial_project_snapshot = self.initial_project_snapshot.clone();
786 cx.spawn(async move |this, cx| {
787 let initial_project_snapshot = initial_project_snapshot.await;
788 this.read_with(cx, |this, cx| SerializedThread {
789 version: SerializedThread::VERSION.to_string(),
790 summary: this.summary_or_default(),
791 updated_at: this.updated_at(),
792 messages: this
793 .messages()
794 .map(|message| SerializedMessage {
795 id: message.id,
796 role: message.role,
797 segments: message
798 .segments
799 .iter()
800 .map(|segment| match segment {
801 MessageSegment::Text(text) => {
802 SerializedMessageSegment::Text { text: text.clone() }
803 }
804 MessageSegment::Thinking(text) => {
805 SerializedMessageSegment::Thinking { text: text.clone() }
806 }
807 })
808 .collect(),
809 tool_uses: this
810 .tool_uses_for_message(message.id, cx)
811 .into_iter()
812 .map(|tool_use| SerializedToolUse {
813 id: tool_use.id,
814 name: tool_use.name,
815 input: tool_use.input,
816 })
817 .collect(),
818 tool_results: this
819 .tool_results_for_message(message.id)
820 .into_iter()
821 .map(|tool_result| SerializedToolResult {
822 tool_use_id: tool_result.tool_use_id.clone(),
823 is_error: tool_result.is_error,
824 content: tool_result.content.clone(),
825 })
826 .collect(),
827 context: message.context.clone(),
828 })
829 .collect(),
830 initial_project_snapshot,
831 cumulative_token_usage: this.cumulative_token_usage,
832 detailed_summary_state: this.detailed_summary_state.clone(),
833 exceeded_window_error: this.exceeded_window_error.clone(),
834 })
835 })
836 }
837
838 pub fn send_to_model(
839 &mut self,
840 model: Arc<dyn LanguageModel>,
841 request_kind: RequestKind,
842 cx: &mut Context<Self>,
843 ) {
844 let mut request = self.to_completion_request(request_kind, cx);
845 if model.supports_tools() {
846 request.tools = {
847 let mut tools = Vec::new();
848 tools.extend(
849 self.tools()
850 .read(cx)
851 .enabled_tools(cx)
852 .into_iter()
853 .filter_map(|tool| {
                            // Skip tools whose input schema can't be generated for this model's input format.
855 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
856 Some(LanguageModelRequestTool {
857 name: tool.name(),
858 description: tool.description(),
859 input_schema,
860 })
861 }),
862 );
863
864 tools
865 };
866 }
867
868 self.stream_completion(request, model, cx);
869 }
870
871 pub fn used_tools_since_last_user_message(&self) -> bool {
872 for message in self.messages.iter().rev() {
873 if self.tool_use.message_has_tool_results(message.id) {
874 return true;
875 } else if message.role == Role::User {
876 return false;
877 }
878 }
879
880 false
881 }
882
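    /// Builds the request sent to the model: an optional system message generated from the
    /// shared project context, followed by every thread message. For [`RequestKind::Chat`],
    /// tool results and tool uses are attached to their messages; for
    /// [`RequestKind::Summarize`], tool traffic is skipped entirely. The final message is
    /// marked cacheable, and a trailing message describing stale files and pending
    /// diagnostics is appended when needed.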
883 pub fn to_completion_request(
884 &self,
885 request_kind: RequestKind,
886 cx: &App,
887 ) -> LanguageModelRequest {
888 let mut request = LanguageModelRequest {
889 messages: vec![],
890 tools: Vec::new(),
891 stop: Vec::new(),
892 temperature: None,
893 };
894
895 if let Some(project_context) = self.project_context.borrow().as_ref() {
896 if let Some(system_prompt) = self
897 .prompt_builder
898 .generate_assistant_system_prompt(project_context)
899 .context("failed to generate assistant system prompt")
900 .log_err()
901 {
902 request.messages.push(LanguageModelRequestMessage {
903 role: Role::System,
904 content: vec![MessageContent::Text(system_prompt)],
905 cache: true,
906 });
907 }
908 } else {
909 log::error!("project_context not set.")
910 }
911
912 for message in &self.messages {
913 let mut request_message = LanguageModelRequestMessage {
914 role: message.role,
915 content: Vec::new(),
916 cache: false,
917 };
918
919 match request_kind {
920 RequestKind::Chat => {
921 self.tool_use
922 .attach_tool_results(message.id, &mut request_message);
923 }
924 RequestKind::Summarize => {
925 // We don't care about tool use during summarization.
926 if self.tool_use.message_has_tool_results(message.id) {
927 continue;
928 }
929 }
930 }
931
932 if !message.segments.is_empty() {
933 request_message
934 .content
935 .push(MessageContent::Text(message.to_string()));
936 }
937
938 match request_kind {
939 RequestKind::Chat => {
940 self.tool_use
941 .attach_tool_uses(message.id, &mut request_message);
942 }
943 RequestKind::Summarize => {
944 // We don't care about tool use during summarization.
945 }
946 };
947
948 request.messages.push(request_message);
949 }
950
951 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
952 if let Some(last) = request.messages.last_mut() {
953 last.cache = true;
954 }
955
956 self.attached_tracked_files_state(&mut request.messages, cx);
957
958 request
959 }
960
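    /// Appends a trailing user message noting files that changed since the model last read
    /// them and, if edits were made since the last diagnostics check, a reminder to re-check
    /// project diagnostics. Nothing is appended when neither condition applies.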
961 fn attached_tracked_files_state(
962 &self,
963 messages: &mut Vec<LanguageModelRequestMessage>,
964 cx: &App,
965 ) {
966 const STALE_FILES_HEADER: &str = "These files changed since last read:";
967
968 let mut stale_message = String::new();
969
970 let action_log = self.action_log.read(cx);
971
972 for stale_file in action_log.stale_buffers(cx) {
973 let Some(file) = stale_file.read(cx).file() else {
974 continue;
975 };
976
977 if stale_message.is_empty() {
                writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
979 }
980
981 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
982 }
983
984 let mut content = Vec::with_capacity(2);
985
986 if !stale_message.is_empty() {
987 content.push(stale_message.into());
988 }
989
990 if action_log.has_edited_files_since_project_diagnostics_check() {
991 content.push(
992 "\n\nWhen you're done making changes, make sure to check project diagnostics \
993 and fix all errors AND warnings you introduced! \
994 DO NOT mention you're going to do this until you're done."
995 .into(),
996 );
997 }
998
999 if !content.is_empty() {
1000 let context_message = LanguageModelRequestMessage {
1001 role: Role::User,
1002 content,
1003 cache: false,
1004 };
1005
1006 messages.push(context_message);
1007 }
1008 }
1009
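    /// Streams a completion for `request` on a spawned task, applying events to the thread
    /// as they arrive: `StartMessage` inserts a new assistant message, text and thinking
    /// chunks are appended to the last assistant message, tool uses are registered, and
    /// token usage is accumulated. When the stream ends, pending tools are run (if the
    /// model stopped for tool use) or any error is surfaced via `ThreadEvent::ShowError`.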
1010 pub fn stream_completion(
1011 &mut self,
1012 request: LanguageModelRequest,
1013 model: Arc<dyn LanguageModel>,
1014 cx: &mut Context<Self>,
1015 ) {
1016 let pending_completion_id = post_inc(&mut self.completion_count);
1017
1018 let task = cx.spawn(async move |thread, cx| {
1019 let stream = model.stream_completion(request, &cx);
1020 let initial_token_usage =
1021 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1022 let stream_completion = async {
1023 let mut events = stream.await?;
1024 let mut stop_reason = StopReason::EndTurn;
1025 let mut current_token_usage = TokenUsage::default();
1026
1027 while let Some(event) = events.next().await {
1028 let event = event?;
1029
1030 thread.update(cx, |thread, cx| {
1031 match event {
1032 LanguageModelCompletionEvent::StartMessage { .. } => {
1033 thread.insert_message(
1034 Role::Assistant,
1035 vec![MessageSegment::Text(String::new())],
1036 cx,
1037 );
1038 }
1039 LanguageModelCompletionEvent::Stop(reason) => {
1040 stop_reason = reason;
1041 }
1042 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1043 thread.cumulative_token_usage = thread.cumulative_token_usage
1044 + token_usage
1045 - current_token_usage;
1046 current_token_usage = token_usage;
1047 }
1048 LanguageModelCompletionEvent::Text(chunk) => {
1049 if let Some(last_message) = thread.messages.last_mut() {
1050 if last_message.role == Role::Assistant {
1051 last_message.push_text(&chunk);
1052 cx.emit(ThreadEvent::StreamedAssistantText(
1053 last_message.id,
1054 chunk,
1055 ));
1056 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1058 // of a new Assistant response.
1059 //
1060 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1061 // will result in duplicating the text of the chunk in the rendered Markdown.
1062 thread.insert_message(
1063 Role::Assistant,
1064 vec![MessageSegment::Text(chunk.to_string())],
1065 cx,
1066 );
1067 };
1068 }
1069 }
1070 LanguageModelCompletionEvent::Thinking(chunk) => {
1071 if let Some(last_message) = thread.messages.last_mut() {
1072 if last_message.role == Role::Assistant {
1073 last_message.push_thinking(&chunk);
1074 cx.emit(ThreadEvent::StreamedAssistantThinking(
1075 last_message.id,
1076 chunk,
1077 ));
1078 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1080 // of a new Assistant response.
1081 //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1084 thread.insert_message(
1085 Role::Assistant,
1086 vec![MessageSegment::Thinking(chunk.to_string())],
1087 cx,
1088 );
1089 };
1090 }
1091 }
1092 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1093 let last_assistant_message_id = thread
1094 .messages
1095 .iter_mut()
1096 .rfind(|message| message.role == Role::Assistant)
1097 .map(|message| message.id)
1098 .unwrap_or_else(|| {
1099 thread.insert_message(Role::Assistant, vec![], cx)
1100 });
1101
1102 thread.tool_use.request_tool_use(
1103 last_assistant_message_id,
1104 tool_use,
1105 cx,
1106 );
1107 }
1108 }
1109
1110 thread.touch_updated_at();
1111 cx.emit(ThreadEvent::StreamedCompletion);
1112 cx.notify();
1113
1114 thread.auto_capture_telemetry(cx);
1115 })?;
1116
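                    // Yield after applying each event so other foreground work (such as
                    // repainting the UI) can run before the next chunk is processed.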
1117 smol::future::yield_now().await;
1118 }
1119
1120 thread.update(cx, |thread, cx| {
1121 thread
1122 .pending_completions
1123 .retain(|completion| completion.id != pending_completion_id);
1124
1125 if thread.summary.is_none() && thread.messages.len() >= 2 {
1126 thread.summarize(cx);
1127 }
1128 })?;
1129
1130 anyhow::Ok(stop_reason)
1131 };
1132
1133 let result = stream_completion.await;
1134
1135 thread
1136 .update(cx, |thread, cx| {
1137 thread.finalize_pending_checkpoint(cx);
1138 match result.as_ref() {
1139 Ok(stop_reason) => match stop_reason {
1140 StopReason::ToolUse => {
1141 let tool_uses = thread.use_pending_tools(cx);
1142 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1143 }
1144 StopReason::EndTurn => {}
1145 StopReason::MaxTokens => {}
1146 },
1147 Err(error) => {
1148 if error.is::<PaymentRequiredError>() {
1149 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1150 } else if error.is::<MaxMonthlySpendReachedError>() {
1151 cx.emit(ThreadEvent::ShowError(
1152 ThreadError::MaxMonthlySpendReached,
1153 ));
1154 } else if let Some(error) =
1155 error.downcast_ref::<ModelRequestLimitReachedError>()
1156 {
1157 cx.emit(ThreadEvent::ShowError(
1158 ThreadError::ModelRequestLimitReached { plan: error.plan },
1159 ));
1160 } else if let Some(known_error) =
1161 error.downcast_ref::<LanguageModelKnownError>()
1162 {
1163 match known_error {
1164 LanguageModelKnownError::ContextWindowLimitExceeded {
1165 tokens,
1166 } => {
1167 thread.exceeded_window_error = Some(ExceededWindowError {
1168 model_id: model.id(),
1169 token_count: *tokens,
1170 });
1171 cx.notify();
1172 }
1173 }
1174 } else {
1175 let error_message = error
1176 .chain()
1177 .map(|err| err.to_string())
1178 .collect::<Vec<_>>()
1179 .join("\n");
1180 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1181 header: "Error interacting with language model".into(),
1182 message: SharedString::from(error_message.clone()),
1183 }));
1184 }
1185
1186 thread.cancel_last_completion(cx);
1187 }
1188 }
1189 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1190
1191 thread.auto_capture_telemetry(cx);
1192
1193 if let Ok(initial_usage) = initial_token_usage {
1194 let usage = thread.cumulative_token_usage - initial_usage;
1195
1196 telemetry::event!(
1197 "Assistant Thread Completion",
1198 thread_id = thread.id().to_string(),
1199 model = model.telemetry_id(),
1200 model_provider = model.provider_id().to_string(),
1201 input_tokens = usage.input_tokens,
1202 output_tokens = usage.output_tokens,
1203 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1204 cache_read_input_tokens = usage.cache_read_input_tokens,
1205 );
1206 }
1207 })
1208 .ok();
1209 });
1210
1211 self.pending_completions.push(PendingCompletion {
1212 id: pending_completion_id,
1213 _task: task,
1214 });
1215 }
1216
1217 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1218 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1219 return;
1220 };
1221
1222 if !model.provider.is_authenticated(cx) {
1223 return;
1224 }
1225
1226 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1227 request.messages.push(LanguageModelRequestMessage {
1228 role: Role::User,
1229 content: vec![
1230 "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1231 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1232 If the conversation is about a specific subject, include it in the title. \
1233 Be descriptive. DO NOT speak in the first person."
1234 .into(),
1235 ],
1236 cache: false,
1237 });
1238
1239 self.pending_summary = cx.spawn(async move |this, cx| {
1240 async move {
1241 let stream = model.model.stream_completion_text(request, &cx);
1242 let mut messages = stream.await?;
1243
1244 let mut new_summary = String::new();
1245 while let Some(message) = messages.stream.next().await {
1246 let text = message?;
1247 let mut lines = text.lines();
1248 new_summary.extend(lines.next());
1249
1250 // Stop if the LLM generated multiple lines.
1251 if lines.next().is_some() {
1252 break;
1253 }
1254 }
1255
1256 this.update(cx, |this, cx| {
1257 if !new_summary.is_empty() {
1258 this.summary = Some(new_summary.into());
1259 }
1260
1261 cx.emit(ThreadEvent::SummaryGenerated);
1262 })?;
1263
1264 anyhow::Ok(())
1265 }
1266 .log_err()
1267 .await
1268 });
1269 }
1270
1271 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1272 let last_message_id = self.messages.last().map(|message| message.id)?;
1273
1274 match &self.detailed_summary_state {
1275 DetailedSummaryState::Generating { message_id, .. }
1276 | DetailedSummaryState::Generated { message_id, .. }
1277 if *message_id == last_message_id =>
1278 {
1279 // Already up-to-date
1280 return None;
1281 }
1282 _ => {}
1283 }
1284
1285 let ConfiguredModel { model, provider } =
1286 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1287
1288 if !provider.is_authenticated(cx) {
1289 return None;
1290 }
1291
1292 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1293
1294 request.messages.push(LanguageModelRequestMessage {
1295 role: Role::User,
1296 content: vec![
1297 "Generate a detailed summary of this conversation. Include:\n\
1298 1. A brief overview of what was discussed\n\
1299 2. Key facts or information discovered\n\
1300 3. Outcomes or conclusions reached\n\
1301 4. Any action items or next steps if any\n\
1302 Format it in Markdown with headings and bullet points."
1303 .into(),
1304 ],
1305 cache: false,
1306 });
1307
1308 let task = cx.spawn(async move |thread, cx| {
1309 let stream = model.stream_completion_text(request, &cx);
1310 let Some(mut messages) = stream.await.log_err() else {
1311 thread
1312 .update(cx, |this, _cx| {
1313 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1314 })
1315 .log_err();
1316
1317 return;
1318 };
1319
1320 let mut new_detailed_summary = String::new();
1321
1322 while let Some(chunk) = messages.stream.next().await {
1323 if let Some(chunk) = chunk.log_err() {
1324 new_detailed_summary.push_str(&chunk);
1325 }
1326 }
1327
1328 thread
1329 .update(cx, |this, _cx| {
1330 this.detailed_summary_state = DetailedSummaryState::Generated {
1331 text: new_detailed_summary.into(),
1332 message_id: last_message_id,
1333 };
1334 })
1335 .log_err();
1336 });
1337
1338 self.detailed_summary_state = DetailedSummaryState::Generating {
1339 message_id: last_message_id,
1340 };
1341
1342 Some(task)
1343 }
1344
1345 pub fn is_generating_detailed_summary(&self) -> bool {
1346 matches!(
1347 self.detailed_summary_state,
1348 DetailedSummaryState::Generating { .. }
1349 )
1350 }
1351
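    /// Runs every idle pending tool use. Tools that require confirmation (and are not
    /// covered by `always_allow_tool_actions`) emit `ThreadEvent::ToolConfirmationNeeded`
    /// instead of running immediately. Returns the pending tool uses that were considered.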
1352 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1353 self.auto_capture_telemetry(cx);
1354 let request = self.to_completion_request(RequestKind::Chat, cx);
1355 let messages = Arc::new(request.messages);
1356 let pending_tool_uses = self
1357 .tool_use
1358 .pending_tool_uses()
1359 .into_iter()
1360 .filter(|tool_use| tool_use.status.is_idle())
1361 .cloned()
1362 .collect::<Vec<_>>();
1363
1364 for tool_use in pending_tool_uses.iter() {
1365 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1366 if tool.needs_confirmation(&tool_use.input, cx)
1367 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1368 {
1369 self.tool_use.confirm_tool_use(
1370 tool_use.id.clone(),
1371 tool_use.ui_text.clone(),
1372 tool_use.input.clone(),
1373 messages.clone(),
1374 tool,
1375 );
1376 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1377 } else {
1378 self.run_tool(
1379 tool_use.id.clone(),
1380 tool_use.ui_text.clone(),
1381 tool_use.input.clone(),
1382 &messages,
1383 tool,
1384 cx,
1385 );
1386 }
1387 }
1388 }
1389
1390 pending_tool_uses
1391 }
1392
1393 pub fn run_tool(
1394 &mut self,
1395 tool_use_id: LanguageModelToolUseId,
1396 ui_text: impl Into<SharedString>,
1397 input: serde_json::Value,
1398 messages: &[LanguageModelRequestMessage],
1399 tool: Arc<dyn Tool>,
1400 cx: &mut Context<Thread>,
1401 ) {
1402 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1403 self.tool_use
1404 .run_pending_tool(tool_use_id, ui_text.into(), task);
1405 }
1406
1407 fn spawn_tool_use(
1408 &mut self,
1409 tool_use_id: LanguageModelToolUseId,
1410 messages: &[LanguageModelRequestMessage],
1411 input: serde_json::Value,
1412 tool: Arc<dyn Tool>,
1413 cx: &mut Context<Thread>,
1414 ) -> Task<()> {
1415 let tool_name: Arc<str> = tool.name().into();
1416
1417 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1418 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1419 } else {
1420 tool.run(
1421 input,
1422 messages,
1423 self.project.clone(),
1424 self.action_log.clone(),
1425 cx,
1426 )
1427 };
1428
1429 cx.spawn({
1430 async move |thread: WeakEntity<Thread>, cx| {
1431 let output = tool_result.output.await;
1432
1433 thread
1434 .update(cx, |thread, cx| {
1435 let pending_tool_use = thread.tool_use.insert_tool_output(
1436 tool_use_id.clone(),
1437 tool_name,
1438 output,
1439 cx,
1440 );
1441 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1442 })
1443 .ok();
1444 }
1445 })
1446 }
1447
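    /// Records that a tool use finished. Once all pending tools have completed, the tool
    /// results are attached as a new user message and a follow-up request is sent to the
    /// default model, unless the run was canceled.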
1448 fn tool_finished(
1449 &mut self,
1450 tool_use_id: LanguageModelToolUseId,
1451 pending_tool_use: Option<PendingToolUse>,
1452 canceled: bool,
1453 cx: &mut Context<Self>,
1454 ) {
1455 if self.all_tools_finished() {
1456 let model_registry = LanguageModelRegistry::read_global(cx);
1457 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1458 self.attach_tool_results(cx);
1459 if !canceled {
1460 self.send_to_model(model, RequestKind::Chat, cx);
1461 }
1462 }
1463 }
1464
1465 cx.emit(ThreadEvent::ToolFinished {
1466 tool_use_id,
1467 pending_tool_use,
1468 });
1469 }
1470
1471 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1472 // Insert a user message to contain the tool results.
1473 self.insert_user_message(
1474 // TODO: Sending up a user message without any content results in the model sending back
1475 // responses that also don't have any content. We currently don't handle this case well,
1476 // so for now we provide some text to keep the model on track.
1477 "Here are the tool results.",
1478 Vec::new(),
1479 None,
1480 cx,
1481 );
1482 }
1483
1484 /// Cancels the last pending completion, if there are any pending.
1485 ///
1486 /// Returns whether a completion was canceled.
1487 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1488 let canceled = if self.pending_completions.pop().is_some() {
1489 true
1490 } else {
1491 let mut canceled = false;
1492 for pending_tool_use in self.tool_use.cancel_pending() {
1493 canceled = true;
1494 self.tool_finished(
1495 pending_tool_use.id.clone(),
1496 Some(pending_tool_use),
1497 true,
1498 cx,
1499 );
1500 }
1501 canceled
1502 };
1503 self.finalize_pending_checkpoint(cx);
1504 canceled
1505 }
1506
1507 pub fn feedback(&self) -> Option<ThreadFeedback> {
1508 self.feedback
1509 }
1510
1511 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1512 self.message_feedback.get(&message_id).copied()
1513 }
1514
1515 pub fn report_message_feedback(
1516 &mut self,
1517 message_id: MessageId,
1518 feedback: ThreadFeedback,
1519 cx: &mut Context<Self>,
1520 ) -> Task<Result<()>> {
1521 if self.message_feedback.get(&message_id) == Some(&feedback) {
1522 return Task::ready(Ok(()));
1523 }
1524
1525 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1526 let serialized_thread = self.serialize(cx);
1527 let thread_id = self.id().clone();
1528 let client = self.project.read(cx).client();
1529
1530 let enabled_tool_names: Vec<String> = self
1531 .tools()
1532 .read(cx)
1533 .enabled_tools(cx)
1534 .iter()
1535 .map(|tool| tool.name().to_string())
1536 .collect();
1537
1538 self.message_feedback.insert(message_id, feedback);
1539
1540 cx.notify();
1541
1542 let message_content = self
1543 .message(message_id)
1544 .map(|msg| msg.to_string())
1545 .unwrap_or_default();
1546
1547 cx.background_spawn(async move {
1548 let final_project_snapshot = final_project_snapshot.await;
1549 let serialized_thread = serialized_thread.await?;
1550 let thread_data =
1551 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1552
1553 let rating = match feedback {
1554 ThreadFeedback::Positive => "positive",
1555 ThreadFeedback::Negative => "negative",
1556 };
1557 telemetry::event!(
1558 "Assistant Thread Rated",
1559 rating,
1560 thread_id,
1561 enabled_tool_names,
1562 message_id = message_id.0,
1563 message_content,
1564 thread_data,
1565 final_project_snapshot
1566 );
1567 client.telemetry().flush_events();
1568
1569 Ok(())
1570 })
1571 }
1572
1573 pub fn report_feedback(
1574 &mut self,
1575 feedback: ThreadFeedback,
1576 cx: &mut Context<Self>,
1577 ) -> Task<Result<()>> {
1578 let last_assistant_message_id = self
1579 .messages
1580 .iter()
1581 .rev()
1582 .find(|msg| msg.role == Role::Assistant)
1583 .map(|msg| msg.id);
1584
1585 if let Some(message_id) = last_assistant_message_id {
1586 self.report_message_feedback(message_id, feedback, cx)
1587 } else {
1588 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1589 let serialized_thread = self.serialize(cx);
1590 let thread_id = self.id().clone();
1591 let client = self.project.read(cx).client();
1592 self.feedback = Some(feedback);
1593 cx.notify();
1594
1595 cx.background_spawn(async move {
1596 let final_project_snapshot = final_project_snapshot.await;
1597 let serialized_thread = serialized_thread.await?;
1598 let thread_data = serde_json::to_value(serialized_thread)
1599 .unwrap_or_else(|_| serde_json::Value::Null);
1600
1601 let rating = match feedback {
1602 ThreadFeedback::Positive => "positive",
1603 ThreadFeedback::Negative => "negative",
1604 };
1605 telemetry::event!(
1606 "Assistant Thread Rated",
1607 rating,
1608 thread_id,
1609 thread_data,
1610 final_project_snapshot
1611 );
1612 client.telemetry().flush_events();
1613
1614 Ok(())
1615 })
1616 }
1617 }
1618
1619 /// Create a snapshot of the current project state including git information and unsaved buffers.
1620 fn project_snapshot(
1621 project: Entity<Project>,
1622 cx: &mut Context<Self>,
1623 ) -> Task<Arc<ProjectSnapshot>> {
1624 let git_store = project.read(cx).git_store().clone();
1625 let worktree_snapshots: Vec<_> = project
1626 .read(cx)
1627 .visible_worktrees(cx)
1628 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1629 .collect();
1630
1631 cx.spawn(async move |_, cx| {
1632 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1633
1634 let mut unsaved_buffers = Vec::new();
1635 cx.update(|app_cx| {
1636 let buffer_store = project.read(app_cx).buffer_store();
1637 for buffer_handle in buffer_store.read(app_cx).buffers() {
1638 let buffer = buffer_handle.read(app_cx);
1639 if buffer.is_dirty() {
1640 if let Some(file) = buffer.file() {
1641 let path = file.path().to_string_lossy().to_string();
1642 unsaved_buffers.push(path);
1643 }
1644 }
1645 }
1646 })
1647 .ok();
1648
1649 Arc::new(ProjectSnapshot {
1650 worktree_snapshots,
1651 unsaved_buffer_paths: unsaved_buffers,
1652 timestamp: Utc::now(),
1653 })
1654 })
1655 }
1656
1657 fn worktree_snapshot(
1658 worktree: Entity<project::Worktree>,
1659 git_store: Entity<GitStore>,
1660 cx: &App,
1661 ) -> Task<WorktreeSnapshot> {
1662 cx.spawn(async move |cx| {
1663 // Get worktree path and snapshot
1664 let worktree_info = cx.update(|app_cx| {
1665 let worktree = worktree.read(app_cx);
1666 let path = worktree.abs_path().to_string_lossy().to_string();
1667 let snapshot = worktree.snapshot();
1668 (path, snapshot)
1669 });
1670
1671 let Ok((worktree_path, _snapshot)) = worktree_info else {
1672 return WorktreeSnapshot {
1673 worktree_path: String::new(),
1674 git_state: None,
1675 };
1676 };
1677
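            // Find the repository containing this worktree, then query its git state
            // (remote, HEAD, branch, diff) on the repository's background job queue.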
1678 let git_state = git_store
1679 .update(cx, |git_store, cx| {
1680 git_store
1681 .repositories()
1682 .values()
1683 .find(|repo| {
1684 repo.read(cx)
1685 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1686 .is_some()
1687 })
1688 .cloned()
1689 })
1690 .ok()
1691 .flatten()
1692 .map(|repo| {
1693 repo.update(cx, |repo, _| {
1694 let current_branch =
1695 repo.branch.as_ref().map(|branch| branch.name.to_string());
1696 repo.send_job(None, |state, _| async move {
1697 let RepositoryState::Local { backend, .. } = state else {
1698 return GitState {
1699 remote_url: None,
1700 head_sha: None,
1701 current_branch,
1702 diff: None,
1703 };
1704 };
1705
1706 let remote_url = backend.remote_url("origin");
1707 let head_sha = backend.head_sha();
1708 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1709
1710 GitState {
1711 remote_url,
1712 head_sha,
1713 current_branch,
1714 diff,
1715 }
1716 })
1717 })
1718 });
1719
1720 let git_state = match git_state {
1721 Some(git_state) => match git_state.ok() {
1722 Some(git_state) => git_state.await.ok(),
1723 None => None,
1724 },
1725 None => None,
1726 };
1727
1728 WorktreeSnapshot {
1729 worktree_path,
1730 git_state,
1731 }
1732 })
1733 }
1734
1735 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1736 let mut markdown = Vec::new();
1737
1738 if let Some(summary) = self.summary() {
1739 writeln!(markdown, "# {summary}\n")?;
1740 };
1741
1742 for message in self.messages() {
1743 writeln!(
1744 markdown,
1745 "## {role}\n",
1746 role = match message.role {
1747 Role::User => "User",
1748 Role::Assistant => "Assistant",
1749 Role::System => "System",
1750 }
1751 )?;
1752
1753 if !message.context.is_empty() {
1754 writeln!(markdown, "{}", message.context)?;
1755 }
1756
1757 for segment in &message.segments {
1758 match segment {
1759 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1760 MessageSegment::Thinking(text) => {
1761 writeln!(markdown, "<think>{}</think>\n", text)?
1762 }
1763 }
1764 }
1765
1766 for tool_use in self.tool_uses_for_message(message.id, cx) {
1767 writeln!(
1768 markdown,
1769 "**Use Tool: {} ({})**",
1770 tool_use.name, tool_use.id
1771 )?;
1772 writeln!(markdown, "```json")?;
1773 writeln!(
1774 markdown,
1775 "{}",
1776 serde_json::to_string_pretty(&tool_use.input)?
1777 )?;
1778 writeln!(markdown, "```")?;
1779 }
1780
1781 for tool_result in self.tool_results_for_message(message.id) {
1782 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1783 if tool_result.is_error {
1784 write!(markdown, " (Error)")?;
1785 }
1786
1787 writeln!(markdown, "**\n")?;
1788 writeln!(markdown, "{}", tool_result.content)?;
1789 }
1790 }
1791
1792 Ok(String::from_utf8_lossy(&markdown).to_string())
1793 }
1794
1795 pub fn keep_edits_in_range(
1796 &mut self,
1797 buffer: Entity<language::Buffer>,
1798 buffer_range: Range<language::Anchor>,
1799 cx: &mut Context<Self>,
1800 ) {
1801 self.action_log.update(cx, |action_log, cx| {
1802 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1803 });
1804 }
1805
1806 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1807 self.action_log
1808 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1809 }
1810
1811 pub fn reject_edits_in_ranges(
1812 &mut self,
1813 buffer: Entity<language::Buffer>,
1814 buffer_ranges: Vec<Range<language::Anchor>>,
1815 cx: &mut Context<Self>,
1816 ) -> Task<Result<()>> {
1817 self.action_log.update(cx, |action_log, cx| {
1818 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1819 })
1820 }
1821
1822 pub fn action_log(&self) -> &Entity<ActionLog> {
1823 &self.action_log
1824 }
1825
1826 pub fn project(&self) -> &Entity<Project> {
1827 &self.project
1828 }
1829
1830 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1831 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1832 return;
1833 }
1834
1835 let now = Instant::now();
1836 if let Some(last) = self.last_auto_capture_at {
1837 if now.duration_since(last).as_secs() < 10 {
1838 return;
1839 }
1840 }
1841
1842 self.last_auto_capture_at = Some(now);
1843
1844 let thread_id = self.id().clone();
1845 let github_login = self
1846 .project
1847 .read(cx)
1848 .user_store()
1849 .read(cx)
1850 .current_user()
1851 .map(|user| user.github_login.clone());
1852 let client = self.project.read(cx).client().clone();
1853 let serialize_task = self.serialize(cx);
1854
1855 cx.background_executor()
1856 .spawn(async move {
1857 if let Ok(serialized_thread) = serialize_task.await {
1858 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1859 telemetry::event!(
1860 "Agent Thread Auto-Captured",
1861 thread_id = thread_id.to_string(),
1862 thread_data = thread_data,
1863 auto_capture_reason = "tracked_user",
1864 github_login = github_login
1865 );
1866
1867 client.telemetry().flush_events();
1868 }
1869 }
1870 })
1871 .detach();
1872 }
1873
1874 pub fn cumulative_token_usage(&self) -> TokenUsage {
1875 self.cumulative_token_usage
1876 }
1877
1878 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1879 let model_registry = LanguageModelRegistry::read_global(cx);
1880 let Some(model) = model_registry.default_model() else {
1881 return TotalTokenUsage::default();
1882 };
1883
1884 let max = model.model.max_token_count();
1885
1886 if let Some(exceeded_error) = &self.exceeded_window_error {
1887 if model.model.id() == exceeded_error.model_id {
1888 return TotalTokenUsage {
1889 total: exceeded_error.token_count,
1890 max,
1891 ratio: TokenUsageRatio::Exceeded,
1892 };
1893 }
1894 }
1895
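        // In debug builds the warning threshold can be overridden via the
        // ZED_THREAD_WARNING_THRESHOLD environment variable; release builds always warn
        // at 80% of the context window.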
1896 #[cfg(debug_assertions)]
1897 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1898 .unwrap_or("0.8".to_string())
1899 .parse()
1900 .unwrap();
1901 #[cfg(not(debug_assertions))]
1902 let warning_threshold: f32 = 0.8;
1903
1904 let total = self.cumulative_token_usage.total_tokens() as usize;
1905
1906 let ratio = if total >= max {
1907 TokenUsageRatio::Exceeded
1908 } else if total as f32 / max as f32 >= warning_threshold {
1909 TokenUsageRatio::Warning
1910 } else {
1911 TokenUsageRatio::Normal
1912 };
1913
1914 TotalTokenUsage { total, max, ratio }
1915 }
1916
1917 pub fn deny_tool_use(
1918 &mut self,
1919 tool_use_id: LanguageModelToolUseId,
1920 tool_name: Arc<str>,
1921 cx: &mut Context<Self>,
1922 ) {
1923 let err = Err(anyhow::anyhow!(
1924 "Permission to run tool action denied by user"
1925 ));
1926
1927 self.tool_use
1928 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1929 self.tool_finished(tool_use_id.clone(), None, true, cx);
1930 }
1931}
1932
1933#[derive(Debug, Clone, Error)]
1934pub enum ThreadError {
1935 #[error("Payment required")]
1936 PaymentRequired,
1937 #[error("Max monthly spend reached")]
1938 MaxMonthlySpendReached,
1939 #[error("Model request limit reached")]
1940 ModelRequestLimitReached { plan: Plan },
1941 #[error("Message {header}: {message}")]
1942 Message {
1943 header: SharedString,
1944 message: SharedString,
1945 },
1946}
1947
1948#[derive(Debug, Clone)]
1949pub enum ThreadEvent {
1950 ShowError(ThreadError),
1951 StreamedCompletion,
1952 StreamedAssistantText(MessageId, String),
1953 StreamedAssistantThinking(MessageId, String),
1954 Stopped(Result<StopReason, Arc<anyhow::Error>>),
1955 MessageAdded(MessageId),
1956 MessageEdited(MessageId),
1957 MessageDeleted(MessageId),
1958 SummaryGenerated,
1959 SummaryChanged,
1960 UsePendingTools {
1961 tool_uses: Vec<PendingToolUse>,
1962 },
1963 ToolFinished {
1964 #[allow(unused)]
1965 tool_use_id: LanguageModelToolUseId,
1966 /// The pending tool use that corresponds to this tool.
1967 pending_tool_use: Option<PendingToolUse>,
1968 },
1969 CheckpointChanged,
1970 ToolConfirmationNeeded,
1971}
1972
1973impl EventEmitter<ThreadEvent> for Thread {}
1974
1975struct PendingCompletion {
1976 id: usize,
1977 _task: Task<()>,
1978}
1979
1980#[cfg(test)]
1981mod tests {
1982 use super::*;
1983 use crate::{ThreadStore, context_store::ContextStore, thread_store};
1984 use assistant_settings::AssistantSettings;
1985 use context_server::ContextServerSettings;
1986 use editor::EditorSettings;
1987 use gpui::TestAppContext;
1988 use project::{FakeFs, Project};
1989 use prompt_store::PromptBuilder;
1990 use serde_json::json;
1991 use settings::{Settings, SettingsStore};
1992 use std::sync::Arc;
1993 use theme::ThemeSettings;
1994 use util::path;
1995 use workspace::Workspace;
1996
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message =
            thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain all 3 user messages (plus one leading message added by the thread)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message =
            thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer so it no longer matches the content captured in the context
        buffer.update(cx, |buffer, cx| {
            // Insert a new line just after the buffer's first character; any edit is
            // enough to mark the previously read buffer as stale
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

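    // Helper to build the workspace, thread store, thread, and context store used by these tests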
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await;

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

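    // Helper to open a file from the test project and add it to the context store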
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}