use std::fmt::Write as _;
use std::io::Write;
use std::ops::Range;
use std::sync::Arc;
use std::time::Instant;

use anyhow::{Context as _, Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap};
use feature_flags::{self, FeatureFlagAppExt};
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
    Role, StopReason, TokenUsage,
};
use project::Project;
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
use prompt_store::PromptBuilder;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;

use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::thread_store::{
    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
    SerializedToolUse, SharedProjectContext,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};

#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}

#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    pub segments: Vec<MessageSegment>,
    pub context: String,
}

impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    /// The model sometimes runs tools without producing any text, or produces just a marker ([`USING_TOOL_MARKER`]).
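    ///
    /// A hedged sketch of the intended behavior (not compiled; assumes a `Message`
    /// constructed elsewhere in this crate):
    ///
    /// ```ignore
    /// let mut message = /* an assistant Message with no segments yet */;
    /// message.push_text(USING_TOOL_MARKER);
    /// assert!(!message.should_display_content());
    /// ```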
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().all(|segment| segment.should_display())
    }

    pub fn push_thinking(&mut self, text: &str) {
        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments
                .push(MessageSegment::Thinking(text.to_string()));
        }
    }

    pub fn push_text(&mut self, text: &str) {
        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Text(text.to_string()));
        }
    }

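    /// Concatenates the message's attached context and all segments into a single
    /// plain-text string, wrapping thinking segments in `<think>` tags.
    ///
    /// Illustrative output shape only (not a compiled doc test):
    ///
    /// ```text
    /// <attached context, if any>visible reply text<think>reasoning...</think>
    /// ```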
    pub fn to_string(&self) -> String {
        let mut result = String::new();

        if !self.context.is_empty() {
            result.push_str(&self.context);
        }

        for segment in &self.segments {
            match segment {
                MessageSegment::Text(text) => result.push_str(text),
                MessageSegment::Thinking(text) => {
                    result.push_str("<think>");
                    result.push_str(text);
                    result.push_str("</think>");
                }
            }
        }

        result
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking(String),
}

impl MessageSegment {
    pub fn text_mut(&mut self) -> &mut String {
        match self {
            Self::Text(text) => text,
            Self::Thinking(text) => text,
        }
    }

    pub fn should_display(&self) -> bool {
        // We add USING_TOOL_MARKER when making a request that includes tool uses
        // without non-whitespace text around them, and this can cause the model
        // to mimic the pattern, so we consider those segments not displayable.
        match self {
            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::Thinking(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}

impl LastRestoreCheckpoint {
    pub fn message_id(&self) -> MessageId {
        match self {
            LastRestoreCheckpoint::Pending { message_id } => *message_id,
            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
        }
    }
}

#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}

#[derive(Default)]
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
    pub ratio: TokenUsageRatio,
}

#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}

/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    next_message_id: MessageId,
    context: BTreeMap<ContextId, AssistantContext>,
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}

impl Thread {
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
        }
    }

    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text } => {
                                MessageSegment::Thinking(text)
                            }
                        })
                        .collect(),
                    context: message.context,
                })
                .collect(),
            next_message_id,
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
        }
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }

    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let Some(current_summary) = &self.summary else {
            // Don't allow setting summary until generated
            return;
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = Self::DEFAULT_SUMMARY;
        }

        if current_summary != &new_summary {
            self.summary = Some(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.latest_detailed_summary()
            .unwrap_or_else(|| self.text().into())
    }

    fn latest_detailed_summary(&self) -> Option<SharedString> {
        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
            Some(text.clone())
        } else {
            None
        }
    }

    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }

    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }

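    /// Resolves the checkpoint that was created when the last user message was sent.
    ///
    /// Once generation has finished, this compares the pending checkpoint against the
    /// current repository state: if nothing changed, the pending checkpoint is deleted;
    /// otherwise it is kept and associated with its message.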
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.context_by_message.remove(&deleted_message.id);
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
        self.context_by_message
            .get(&id)
            .into_iter()
            .flat_map(|context| {
                context
                    .iter()
                    .filter_map(|context_id| self.context.get(&context_id))
            })
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }

    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        // Filter out contexts that have already been included in previous messages
        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| !self.context.contains_key(&ctx.id()))
            .collect();

        if !new_context.is_empty() {
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.buffer_added_as_context(
                                symbol_ctx.context_symbol.buffer.clone(),
                                cx,
                            );
                        }
                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
                    }
                }
            });
        }

        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            context: String::new(),
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }

    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.context_by_message.remove(&id);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }

    /// Returns the representation of this [`Thread`] in a textual form.
    ///
    /// This is the representation we use when attaching a thread as context to another thread.
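    ///
    /// Illustrative output shape only (not generated from a real thread):
    ///
    /// ```text
    /// User:
    /// Please explain this code
    /// Assistant:
    /// <think>reasoning</think>Here is an explanation...
    /// ```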
    pub fn text(&self) -> String {
        let mut text = String::new();

        for message in &self.messages {
            text.push_str(match message.role {
                language_model::Role::User => "User:",
                language_model::Role::Assistant => "Assistant:",
                language_model::Role::System => "System:",
            });
            text.push('\n');

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(content) => text.push_str(content),
                    MessageSegment::Thinking(content) => {
                        text.push_str(&format!("<think>{}</think>", content))
                    }
                }
            }
            text.push('\n');
        }

        text
    }

    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking(text) => {
                                    SerializedMessageSegment::Thinking { text: text.clone() }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }

    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        request_kind: RequestKind,
        cx: &mut Context<Self>,
    ) {
        let mut request = self.to_completion_request(request_kind, cx);
        if model.supports_tools() {
            request.tools = {
                let mut tools = Vec::new();
                tools.extend(
                    self.tools()
                        .read(cx)
                        .enabled_tools(cx)
                        .into_iter()
                        .filter_map(|tool| {
                            // Skip tools that cannot be supported
                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                            Some(LanguageModelRequestTool {
                                name: tool.name(),
                                description: tool.description(),
                                input_schema,
                            })
                        }),
                );

                tools
            };
        }

        self.stream_completion(request, model, cx);
    }

    pub fn used_tools_since_last_user_message(&self) -> bool {
        for message in self.messages.iter().rev() {
            if self.tool_use.message_has_tool_results(message.id) {
                return true;
            } else if message.role == Role::User {
                return false;
            }
        }

        false
    }

    pub fn to_completion_request(
        &self,
        request_kind: RequestKind,
        cx: &App,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            if let Some(system_prompt) = self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
                .context("failed to generate assistant system prompt")
                .log_err()
            {
                request.messages.push(LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text(system_prompt)],
                    cache: true,
                });
            }
        } else {
            log::error!("project_context not set.")
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_results(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                    if self.tool_use.message_has_tool_results(message.id) {
                        continue;
                    }
                }
            }

            if !message.segments.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.to_string()));
            }

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            };

            request.messages.push(request_message);
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }

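    /// Appends a trailing user message that lists files which changed since the model
    /// last read them and, when edits were made, reminds the model to re-check project
    /// diagnostics.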
    fn attached_tracked_files_state(
        &self,
        messages: &mut Vec<LanguageModelRequestMessage>,
        cx: &App,
    ) {
        const STALE_FILES_HEADER: &str = "These files changed since last read:";

        let mut stale_message = String::new();

        let action_log = self.action_log.read(cx);

        for stale_file in action_log.stale_buffers(cx) {
            let Some(file) = stale_file.read(cx).file() else {
                continue;
            };

            if stale_message.is_empty() {
                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
            }

            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
        }

        let mut content = Vec::with_capacity(2);

        if !stale_message.is_empty() {
            content.push(stale_message.into());
        }

        if action_log.has_edited_files_since_project_diagnostics_check() {
            content.push(
                "\n\nWhen you're done making changes, make sure to check project diagnostics \
                and fix all errors AND warnings you introduced! \
                DO NOT mention you're going to do this until you're done."
                    .into(),
            );
        }

        if !content.is_empty() {
            let context_message = LanguageModelRequestMessage {
                role: Role::User,
                content,
                cache: false,
            };

            messages.push(context_message);
        }
    }

    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }

    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
                Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
                If the conversation is about a specific subject, include it in the title. \
                Be descriptive. DO NOT speak in the first person."
                    .into(),
            ],
            cache: false,
        });

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text(request, &cx);
                let mut messages = stream.await?;

                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }

    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let mut request = self.to_completion_request(RequestKind::Summarize, cx);

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Generate a detailed summary of this conversation. Include:\n\
                1. A brief overview of what was discussed\n\
                2. Key facts or information discovered\n\
                3. Outcomes or conclusions reached\n\
                4. Any action items or next steps if any\n\
                Format it in Markdown with headings and bullet points."
                    .into(),
            ],
            cache: false,
        });

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }

    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }

    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        let request = self.to_completion_request(RequestKind::Chat, cx);
        let messages = Arc::new(request.messages);
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        cx,
                    );
                }
            }
        }

        pending_tool_uses
    }

    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        messages: &[LanguageModelRequestMessage],
        tool: Arc<dyn Tool>,
        cx: &mut Context<Thread>,
    ) {
        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }

    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                cx,
            )
        };

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            cx,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
                    })
                    .ok();
            }
        })
    }

    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            let model_registry = LanguageModelRegistry::read_global(cx);
            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
                self.attach_tool_results(cx);
                if !canceled {
                    self.send_to_model(model, RequestKind::Chat, cx);
                }
            }
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }

    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Insert a user message to contain the tool results.
        self.insert_user_message(
            // TODO: Sending up a user message without any content results in the model sending back
            // responses that also don't have any content. We currently don't handle this case well,
            // so for now we provide some text to keep the model on track.
            "Here are the tool results.",
            Vec::new(),
            None,
            cx,
        );
    }

    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
        let canceled = if self.pending_completions.pop().is_some() {
            true
        } else {
            let mut canceled = false;
            for pending_tool_use in self.tool_use.cancel_pending() {
                canceled = true;
                self.tool_finished(
                    pending_tool_use.id.clone(),
                    Some(pending_tool_use),
                    true,
                    cx,
                );
            }
            canceled
        };
        self.finalize_pending_checkpoint(cx);
        canceled
    }

    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }

    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }

    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events();

            Ok(())
        })
    }

    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events();

                Ok(())
            })
        }
    }

    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }

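    /// Captures a single worktree's state for a [`ProjectSnapshot`]: its absolute path
    /// and, when a local git repository is found, the remote URL, HEAD SHA, current
    /// branch, and a HEAD-to-worktree diff.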
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name.to_string());
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha();
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }

    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        let mut markdown = Vec::new();

        if let Some(summary) = self.summary() {
            writeln!(markdown, "# {summary}\n")?;
        };

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Assistant",
                    Role::System => "System",
                }
            )?;

            if !message.context.is_empty() {
                writeln!(markdown, "{}", message.context)?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking(text) => {
                        writeln!(markdown, "<think>{}</think>\n", text)?
                    }
                }
            }

            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                writeln!(markdown, "{}", tool_result.content)?;
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }

    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }

    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }

    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }

    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
            return;
        }

        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client().clone();
        let serialize_task = self.serialize(cx);

        cx.background_executor()
            .spawn(async move {
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events();
                    }
                }
            })
            .detach();
    }

    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }

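    /// Reports the thread's token usage against the default model's context window.
    ///
    /// The ratio is `Exceeded` once `total >= max` and `Warning` once `total / max`
    /// reaches the warning threshold (0.8 by default). As a rough illustration, with
    /// `max = 200_000` and `total = 170_000`, the ratio is 0.85, so this would return
    /// [`TokenUsageRatio::Warning`].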
    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
        let model_registry = LanguageModelRegistry::read_global(cx);
        let Some(model) = model_registry.default_model() else {
            return TotalTokenUsage::default();
        };

        let max = model.model.max_token_count();

        if let Some(exceeded_error) = &self.exceeded_window_error {
            if model.model.id() == exceeded_error.model_id {
                return TotalTokenUsage {
                    total: exceeded_error.token_count,
                    max,
                    ratio: TokenUsageRatio::Exceeded,
                };
            }
        }

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .unwrap_or("0.8".to_string())
            .parse()
            .unwrap();
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        let total = self.cumulative_token_usage.total_tokens() as usize;

        let ratio = if total >= max {
            TokenUsageRatio::Exceeded
        } else if total as f32 / max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        };

        TotalTokenUsage { total, max, ratio }
    }

    pub fn deny_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        cx: &mut Context<Self>,
    ) {
        let err = Err(anyhow::anyhow!(
            "Permission to run tool action denied by user"
        ));

        self.tool_use
            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
        self.tool_finished(tool_use_id.clone(), None, true, cx);
    }
}

#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    #[error("Payment required")]
    PaymentRequired,
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}

#[derive(Debug, Clone)]
pub enum ThreadEvent {
    ShowError(ThreadError),
    StreamedCompletion,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}

impl EventEmitter<ThreadEvent> for Thread {}

struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain the system prompt plus all three user messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Edit the buffer so its contents no longer match the context that was captured
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

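    // Helper to register the settings and globals that threads depend on in tests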
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

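    // Helper to build the workspace, thread store, thread, and context store used in tests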
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await;

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

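    // Helper to open a buffer for `path` and add it to the context store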
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}