use std::fmt::Write as _;
use std::io::Write;
use std::ops::Range;
use std::sync::Arc;
use std::time::Instant;

use anyhow::{Context as _, Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap};
use feature_flags::{self, FeatureFlagAppExt};
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
    Role, StopReason, TokenUsage,
};
use project::Project;
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
use prompt_store::PromptBuilder;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;

use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::thread_store::{
    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
    SerializedToolUse, SharedProjectContext,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};

#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}

#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    pub segments: Vec<MessageSegment>,
    pub context: String,
}

impl Message {
    /// Returns whether the message's segments contain meaningful text that should be displayed.
    /// The model sometimes runs tools without producing any text, or produces only a marker ([`USING_TOOL_MARKER`]).
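    ///
    /// A minimal sketch of the intended behavior (not compiled; the `MessageId` and `Role`
    /// values are illustrative):
    ///
    /// ```ignore
    /// let message = Message {
    ///     id: MessageId(0),
    ///     role: Role::Assistant,
    ///     segments: vec![MessageSegment::Text(USING_TOOL_MARKER.to_string())],
    ///     context: String::new(),
    /// };
    /// // A marker-only segment carries no meaningful text, so nothing should be displayed.
    /// assert!(!message.should_display_content());
    /// ```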
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().all(|segment| segment.should_display())
    }

    pub fn push_thinking(&mut self, text: &str) {
        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments
                .push(MessageSegment::Thinking(text.to_string()));
        }
    }

    pub fn push_text(&mut self, text: &str) {
        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
            segment.push_str(text);
        } else {
            self.segments.push(MessageSegment::Text(text.to_string()));
        }
    }

    pub fn to_string(&self) -> String {
        let mut result = String::new();

        if !self.context.is_empty() {
            result.push_str(&self.context);
        }

        for segment in &self.segments {
            match segment {
                MessageSegment::Text(text) => result.push_str(text),
                MessageSegment::Thinking(text) => {
                    result.push_str("<think>");
                    result.push_str(text);
                    result.push_str("</think>");
                }
            }
        }

        result
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking(String),
}

impl MessageSegment {
    pub fn text_mut(&mut self) -> &mut String {
        match self {
            Self::Text(text) => text,
            Self::Thinking(text) => text,
        }
    }

    pub fn should_display(&self) -> bool {
        // We add USING_TOOL_MARKER when making a request that includes tool uses
        // without non-whitespace text around them, and this can cause the model
        // to mimic the pattern, so we consider those segments not displayable.
        match self {
            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::Thinking(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}

impl LastRestoreCheckpoint {
    pub fn message_id(&self) -> MessageId {
        match self {
            LastRestoreCheckpoint::Pending { message_id } => *message_id,
            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
        }
    }
}

#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}

#[derive(Default)]
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
    pub ratio: TokenUsageRatio,
}

#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}

/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    next_message_id: MessageId,
    context: BTreeMap<ContextId, AssistantContext>,
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Arc<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}

impl Thread {
    pub fn new(
        project: Entity<Project>,
        tools: Arc<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
        }
    }

    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Arc<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text } => {
                                MessageSegment::Thinking(text)
                            }
                        })
                        .collect(),
                    context: message.context,
                })
                .collect(),
            next_message_id,
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
        }
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }

    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let Some(current_summary) = &self.summary else {
            // Don't allow setting summary until generated
            return;
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = Self::DEFAULT_SUMMARY;
        }

        if current_summary != &new_summary {
            self.summary = Some(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.latest_detailed_summary()
            .unwrap_or_else(|| self.text().into())
    }

    fn latest_detailed_summary(&self) -> Option<SharedString> {
        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
            Some(text.clone())
        } else {
            None
        }
    }

    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }

    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
        &self.tools
    }

    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }

    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

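    /// Removes the message with the given ID and every message after it, along with their
    /// associated context and checkpoint entries.
    ///
    /// A minimal usage sketch (not compiled; assumes a `thread: &mut Thread`, a previously
    /// returned `message_id`, and a `cx: &mut Context<Thread>` in scope):
    ///
    /// ```ignore
    /// // Drop everything from `message_id` onward, e.g. after restoring a checkpoint.
    /// thread.truncate(message_id, cx);
    /// assert!(thread.message(message_id).is_none());
    /// ```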
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.context_by_message.remove(&deleted_message.id);
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
        self.context_by_message
            .get(&id)
            .into_iter()
            .flat_map(|context| {
                context
                    .iter()
                    .filter_map(|context_id| self.context.get(&context_id))
            })
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }

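    /// Appends a user message to the thread, attaching any context that has not already been
    /// included in an earlier message.
    ///
    /// Illustrative sketch (not compiled; `thread`, `cx`, and the empty context vector are
    /// assumptions for the example):
    ///
    /// ```ignore
    /// let message_id = thread.insert_user_message(
    ///     "Explain this function",
    ///     Vec::new(), // no attached context
    ///     None,       // no git checkpoint to associate
    ///     cx,
    /// );
    /// ```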
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        // Filter out contexts that have already been included in previous messages
        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| !self.context.contains_key(&ctx.id()))
            .collect();

        if !new_context.is_empty() {
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.buffer_added_as_context(
                                symbol_ctx.context_symbol.buffer.clone(),
                                cx,
                            );
                        }
                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
                    }
                }
            });
        }

        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            context: String::new(),
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }

    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.context_by_message.remove(&id);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }

    /// Returns the representation of this [`Thread`] in a textual form.
    ///
    /// This is the representation we use when attaching a thread as context to another thread.
    pub fn text(&self) -> String {
        let mut text = String::new();

        for message in &self.messages {
            text.push_str(match message.role {
                language_model::Role::User => "User:",
                language_model::Role::Assistant => "Assistant:",
                language_model::Role::System => "System:",
            });
            text.push('\n');

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(content) => text.push_str(content),
                    MessageSegment::Thinking(content) => {
                        text.push_str(&format!("<think>{}</think>", content))
                    }
                }
            }
            text.push('\n');
        }

        text
    }

    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking(text) => {
                                    SerializedMessageSegment::Thinking { text: text.clone() }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }

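    /// Sends the current thread state to the given model, including the enabled tools when the
    /// model supports tool calling.
    ///
    /// Hedged sketch of a call site (not compiled; obtaining `model` from the registry this way
    /// is an assumption for the example):
    ///
    /// ```ignore
    /// if let Some(ConfiguredModel { model, .. }) =
    ///     LanguageModelRegistry::read_global(cx).default_model()
    /// {
    ///     thread.send_to_model(model, RequestKind::Chat, cx);
    /// }
    /// ```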
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        request_kind: RequestKind,
        cx: &mut Context<Self>,
    ) {
        let mut request = self.to_completion_request(request_kind, cx);
        if model.supports_tools() {
            request.tools = {
                let mut tools = Vec::new();
                tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
                    LanguageModelRequestTool {
                        name: tool.name(),
                        description: tool.description(),
                        input_schema: tool.input_schema(model.tool_input_format()),
                    }
                }));

                tools
            };
        }

        self.stream_completion(request, model, cx);
    }

    pub fn used_tools_since_last_user_message(&self) -> bool {
        for message in self.messages.iter().rev() {
            if self.tool_use.message_has_tool_results(message.id) {
                return true;
            } else if message.role == Role::User {
                return false;
            }
        }

        false
    }

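    /// Builds a [`LanguageModelRequest`] from the thread's messages. `RequestKind::Chat`
    /// includes tool uses and tool results, while `RequestKind::Summarize` skips them.
    ///
    /// Minimal sketch (not compiled; `thread` and `cx` are assumed to be in scope):
    ///
    /// ```ignore
    /// let request = thread.to_completion_request(RequestKind::Chat, cx);
    /// // The system prompt, when available, is always the first message.
    /// assert_eq!(request.messages.first().map(|m| m.role), Some(Role::System));
    /// ```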
    pub fn to_completion_request(
        &self,
        request_kind: RequestKind,
        cx: &App,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            if let Some(system_prompt) = self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
                .context("failed to generate assistant system prompt")
                .log_err()
            {
                request.messages.push(LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text(system_prompt)],
                    cache: true,
                });
            }
        } else {
            log::error!("project_context not set.")
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_results(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                    if self.tool_use.message_has_tool_results(message.id) {
                        continue;
                    }
                }
            }

            if !message.segments.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.to_string()));
            }

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            };

            request.messages.push(request_message);
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }

    fn attached_tracked_files_state(
        &self,
        messages: &mut Vec<LanguageModelRequestMessage>,
        cx: &App,
    ) {
        const STALE_FILES_HEADER: &str = "These files changed since last read:";

        let mut stale_message = String::new();

        let action_log = self.action_log.read(cx);

        for stale_file in action_log.stale_buffers(cx) {
            let Some(file) = stale_file.read(cx).file() else {
                continue;
            };

            if stale_message.is_empty() {
                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
            }

            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
        }

        let mut content = Vec::with_capacity(2);

        if !stale_message.is_empty() {
            content.push(stale_message.into());
        }

        if action_log.has_edited_files_since_project_diagnostics_check() {
            content.push(
                "\n\nWhen you're done making changes, make sure to check project diagnostics \
                 and fix all errors AND warnings you introduced! \
                 DO NOT mention you're going to do this until you're done."
                    .into(),
            );
        }

        if !content.is_empty() {
            let context_message = LanguageModelRequestMessage {
                role: Role::User,
                content,
                cache: false,
            };

            messages.push(context_message);
        }
    }

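    /// Streams a completion for `request` from `model`, updating the thread as events arrive
    /// and emitting [`ThreadEvent`]s for the UI to observe.
    ///
    /// Minimal sketch (not compiled; `thread`, `model`, and `cx` are assumed to be in scope):
    ///
    /// ```ignore
    /// let request = thread.to_completion_request(RequestKind::Chat, cx);
    /// thread.stream_completion(request, model, cx);
    /// assert!(thread.is_generating());
    /// ```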
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.cumulative_token_usage =
                                    thread.cumulative_token_usage.clone() + token_usage.clone()
                                        - current_token_usage.clone();
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage.clone() - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }

    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
                 If the conversation is about a specific subject, include it in the title. \
                 Be descriptive. DO NOT speak in the first person."
                    .into(),
            ],
            cache: false,
        });

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text(request, &cx);
                let mut messages = stream.await?;

                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }

    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let mut request = self.to_completion_request(RequestKind::Summarize, cx);

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                "Generate a detailed summary of this conversation. Include:\n\
                 1. A brief overview of what was discussed\n\
                 2. Key facts or information discovered\n\
                 3. Outcomes or conclusions reached\n\
                 4. Any action items or next steps if any\n\
                 Format it in Markdown with headings and bullet points."
                    .into(),
            ],
            cache: false,
        });

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }

    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }

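    /// Runs every idle pending tool use, or queues it for confirmation when the tool requires
    /// it and `always_allow_tool_actions` is disabled.
    ///
    /// Sketch of the typical call site, mirroring how the completion loop reacts to a
    /// `StopReason::ToolUse` (not compiled):
    ///
    /// ```ignore
    /// let tool_uses = thread.use_pending_tools(cx);
    /// cx.emit(ThreadEvent::UsePendingTools { tool_uses });
    /// ```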
    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        let request = self.to_completion_request(RequestKind::Chat, cx);
        let messages = Arc::new(request.messages);
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        cx,
                    );
                }
            }
        }

        pending_tool_uses
    }

    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        messages: &[LanguageModelRequestMessage],
        tool: Arc<dyn Tool>,
        cx: &mut Context<Thread>,
    ) {
        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }

    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                cx,
            )
        };

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = run_tool.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            cx,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
                    })
                    .ok();
            }
        })
    }

    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            let model_registry = LanguageModelRegistry::read_global(cx);
            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
                self.attach_tool_results(cx);
                if !canceled {
                    self.send_to_model(model, RequestKind::Chat, cx);
                }
            }
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }

    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Insert a user message to contain the tool results.
        self.insert_user_message(
            // TODO: Sending up a user message without any content results in the model sending back
            // responses that also don't have any content. We currently don't handle this case well,
            // so for now we provide some text to keep the model on track.
            "Here are the tool results.",
            Vec::new(),
            None,
            cx,
        );
    }

    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
        let canceled = if self.pending_completions.pop().is_some() {
            true
        } else {
            let mut canceled = false;
            for pending_tool_use in self.tool_use.cancel_pending() {
                canceled = true;
                self.tool_finished(
                    pending_tool_use.id.clone(),
                    Some(pending_tool_use),
                    true,
                    cx,
                );
            }
            canceled
        };
        self.finalize_pending_checkpoint(cx);
        canceled
    }

    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }

    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }

    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events();

            Ok(())
        })
    }

    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events();

                Ok(())
            })
        }
    }

    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }

    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name.to_string());
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha();
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }

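    /// Renders the whole thread as Markdown: the summary as a top-level heading, one `## Role`
    /// section per message, plus any tool uses and tool results.
    ///
    /// Minimal sketch (not compiled; `thread` and `cx` are assumed to be in scope):
    ///
    /// ```ignore
    /// let markdown = thread.to_markdown(cx)?;
    /// println!("{markdown}");
    /// ```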
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        let mut markdown = Vec::new();

        if let Some(summary) = self.summary() {
            writeln!(markdown, "# {summary}\n")?;
        };

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Assistant",
                    Role::System => "System",
                }
            )?;

            if !message.context.is_empty() {
                writeln!(markdown, "{}", message.context)?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking(text) => {
                        writeln!(markdown, "<think>{}</think>\n", text)?
                    }
                }
            }

            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                writeln!(markdown, "{}", tool_result.content)?;
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }

    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }

    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }

    pub fn reject_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_range(buffer, buffer_range, cx)
        })
    }

    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
            return;
        }

        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client().clone();
        let serialize_task = self.serialize(cx);

        cx.background_executor()
            .spawn(async move {
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events();
                    }
                }
            })
            .detach();
    }

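    /// Reports cumulative token usage against the default model's context window, classifying
    /// it as `Normal`, `Warning` (at or above the warning threshold, 0.8 by default), or
    /// `Exceeded`.
    ///
    /// Minimal sketch (not compiled; `thread` and `cx` are assumed to be in scope):
    ///
    /// ```ignore
    /// let usage = thread.total_token_usage(cx);
    /// if usage.ratio == TokenUsageRatio::Exceeded {
    ///     // Start a new thread or trim context before sending another request.
    /// }
    /// ```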
1858 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1859 let model_registry = LanguageModelRegistry::read_global(cx);
1860 let Some(model) = model_registry.default_model() else {
1861 return TotalTokenUsage::default();
1862 };
1863
1864 let max = model.model.max_token_count();
1865
1866 if let Some(exceeded_error) = &self.exceeded_window_error {
1867 if model.model.id() == exceeded_error.model_id {
1868 return TotalTokenUsage {
1869 total: exceeded_error.token_count,
1870 max,
1871 ratio: TokenUsageRatio::Exceeded,
1872 };
1873 }
1874 }
1875
1876 #[cfg(debug_assertions)]
1877 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1878 .unwrap_or("0.8".to_string())
1879 .parse()
1880 .unwrap();
1881 #[cfg(not(debug_assertions))]
1882 let warning_threshold: f32 = 0.8;
1883
1884 let total = self.cumulative_token_usage.total_tokens() as usize;
1885
1886 let ratio = if total >= max {
1887 TokenUsageRatio::Exceeded
1888 } else if total as f32 / max as f32 >= warning_threshold {
1889 TokenUsageRatio::Warning
1890 } else {
1891 TokenUsageRatio::Normal
1892 };
1893
1894 TotalTokenUsage { total, max, ratio }
1895 }
1896
1897 pub fn deny_tool_use(
1898 &mut self,
1899 tool_use_id: LanguageModelToolUseId,
1900 tool_name: Arc<str>,
1901 cx: &mut Context<Self>,
1902 ) {
1903 let err = Err(anyhow::anyhow!(
1904 "Permission to run tool action denied by user"
1905 ));
1906
1907 self.tool_use
1908 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1909 self.tool_finished(tool_use_id.clone(), None, true, cx);
1910 }
1911}
1912
1913#[derive(Debug, Clone, Error)]
1914pub enum ThreadError {
1915 #[error("Payment required")]
1916 PaymentRequired,
1917 #[error("Max monthly spend reached")]
1918 MaxMonthlySpendReached,
1919 #[error("Message {header}: {message}")]
1920 Message {
1921 header: SharedString,
1922 message: SharedString,
1923 },
1924}
1925
1926#[derive(Debug, Clone)]
1927pub enum ThreadEvent {
1928 ShowError(ThreadError),
1929 StreamedCompletion,
1930 StreamedAssistantText(MessageId, String),
1931 StreamedAssistantThinking(MessageId, String),
1932 Stopped(Result<StopReason, Arc<anyhow::Error>>),
1933 MessageAdded(MessageId),
1934 MessageEdited(MessageId),
1935 MessageDeleted(MessageId),
1936 SummaryGenerated,
1937 SummaryChanged,
1938 UsePendingTools {
1939 tool_uses: Vec<PendingToolUse>,
1940 },
1941 ToolFinished {
1942 #[allow(unused)]
1943 tool_use_id: LanguageModelToolUseId,
1944 /// The pending tool use that corresponds to this tool.
1945 pending_tool_use: Option<PendingToolUse>,
1946 },
1947 CheckpointChanged,
1948 ToolConfirmationNeeded,
1949}
1950
1951impl EventEmitter<ThreadEvent> for Thread {}
1952
1953struct PendingCompletion {
1954 id: usize,
1955 _task: Task<()>,
1956}
1957
1958#[cfg(test)]
1959mod tests {
1960 use super::*;
1961 use crate::{ThreadStore, context_store::ContextStore, thread_store};
1962 use assistant_settings::AssistantSettings;
1963 use context_server::ContextServerSettings;
1964 use editor::EditorSettings;
1965 use gpui::TestAppContext;
1966 use project::{FakeFs, Project};
1967 use prompt_store::PromptBuilder;
1968 use serde_json::json;
1969 use settings::{Settings, SettingsStore};
1970 use std::sync::Arc;
1971 use theme::ThemeSettings;
1972 use util::path;
1973 use workspace::Workspace;
1974
1975 #[gpui::test]
1976 async fn test_message_with_context(cx: &mut TestAppContext) {
1977 init_test_settings(cx);
1978
1979 let project = create_test_project(
1980 cx,
1981 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1982 )
1983 .await;
1984
1985 let (_workspace, _thread_store, thread, context_store) =
1986 setup_test_environment(cx, project.clone()).await;
1987
1988 add_file_to_context(&project, &context_store, "test/code.rs", cx)
1989 .await
1990 .unwrap();
1991
1992 let context =
1993 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
1994
1995 // Insert user message with context
1996 let message_id = thread.update(cx, |thread, cx| {
1997 thread.insert_user_message("Please explain this code", vec![context], None, cx)
1998 });
1999
2000 // Check content and context in message object
2001 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2002
2003 // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain the system message plus all 3 user messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at offset 1; any edit makes the previously attached context stale
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

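    // Helper to build a workspace, thread store, thread, and context store for tests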
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    Arc::default(),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await;

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

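    // Helper to open a buffer for `path` and add it to the context store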
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}