1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5
6use agent_rules::load_worktree_rules_file;
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use fs::Fs;
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
18use language_model::{
19 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
20 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
21 LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 PaymentRequiredError, Role, StopReason, TokenUsage,
23};
24use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
25use project::{Project, Worktree};
26use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
27use schemars::JsonSchema;
28use serde::{Deserialize, Serialize};
29use settings::Settings;
30use util::{ResultExt as _, TryFutureExt as _, post_inc};
31use uuid::Uuid;
32
33use crate::context::{AssistantContext, ContextId, format_context_as_string};
34use crate::thread_store::{
35 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
36 SerializedToolUse,
37};
38use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
39
40#[derive(Debug, Clone, Copy)]
41pub enum RequestKind {
42 Chat,
43 /// Used when summarizing a thread.
44 Summarize,
45}
46
47#[derive(
48 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
49)]
50pub struct ThreadId(Arc<str>);
51
52impl ThreadId {
53 pub fn new() -> Self {
54 Self(Uuid::new_v4().to_string().into())
55 }
56}
57
58impl std::fmt::Display for ThreadId {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 write!(f, "{}", self.0)
61 }
62}
63
64impl From<&str> for ThreadId {
65 fn from(value: &str) -> Self {
66 Self(value.into())
67 }
68}
69
70#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
71pub struct MessageId(pub(crate) usize);
72
73impl MessageId {
74 fn post_inc(&mut self) -> Self {
75 Self(post_inc(&mut self.0))
76 }
77}
78
79/// A message in a [`Thread`].
80#[derive(Debug, Clone)]
81pub struct Message {
82 pub id: MessageId,
83 pub role: Role,
84 pub segments: Vec<MessageSegment>,
85 pub context: String,
86}
87
88impl Message {
    /// Returns whether the message should be displayed in the UI.
    ///
    /// The model sometimes runs tools without producing any text, or emits only a
    /// marker ([`USING_TOOL_MARKER`]); such messages have nothing meaningful to show.
91 pub fn should_display_content(&self) -> bool {
92 self.segments.iter().all(|segment| segment.should_display())
93 }
94
95 pub fn push_thinking(&mut self, text: &str) {
96 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
97 segment.push_str(text);
98 } else {
99 self.segments
100 .push(MessageSegment::Thinking(text.to_string()));
101 }
102 }
103
104 pub fn push_text(&mut self, text: &str) {
105 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
106 segment.push_str(text);
107 } else {
108 self.segments.push(MessageSegment::Text(text.to_string()));
109 }
110 }
111
112 pub fn to_string(&self) -> String {
113 let mut result = String::new();
114
115 if !self.context.is_empty() {
116 result.push_str(&self.context);
117 }
118
119 for segment in &self.segments {
120 match segment {
121 MessageSegment::Text(text) => result.push_str(text),
122 MessageSegment::Thinking(text) => {
123 result.push_str("<think>");
124 result.push_str(text);
125 result.push_str("</think>");
126 }
127 }
128 }
129
130 result
131 }
132}
133
134#[derive(Debug, Clone, PartialEq, Eq)]
135pub enum MessageSegment {
136 Text(String),
137 Thinking(String),
138}
139
140impl MessageSegment {
141 pub fn text_mut(&mut self) -> &mut String {
142 match self {
143 Self::Text(text) => text,
144 Self::Thinking(text) => text,
145 }
146 }
147
148 pub fn should_display(&self) -> bool {
149 // We add USING_TOOL_MARKER when making a request that includes tool uses
150 // without non-whitespace text around them, and this can cause the model
151 // to mimic the pattern, so we consider those segments not displayable.
152 match self {
            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::Thinking(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
155 }
156 }
157}
158
159#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct ProjectSnapshot {
161 pub worktree_snapshots: Vec<WorktreeSnapshot>,
162 pub unsaved_buffer_paths: Vec<String>,
163 pub timestamp: DateTime<Utc>,
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct WorktreeSnapshot {
168 pub worktree_path: String,
169 pub git_state: Option<GitState>,
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct GitState {
174 pub remote_url: Option<String>,
175 pub head_sha: Option<String>,
176 pub current_branch: Option<String>,
177 pub diff: Option<String>,
178}
179
180#[derive(Clone)]
181pub struct ThreadCheckpoint {
182 message_id: MessageId,
183 git_checkpoint: GitStoreCheckpoint,
184}
185
186#[derive(Copy, Clone, Debug, PartialEq, Eq)]
187pub enum ThreadFeedback {
188 Positive,
189 Negative,
190}
191
192pub enum LastRestoreCheckpoint {
193 Pending {
194 message_id: MessageId,
195 },
196 Error {
197 message_id: MessageId,
198 error: String,
199 },
200}
201
202impl LastRestoreCheckpoint {
203 pub fn message_id(&self) -> MessageId {
204 match self {
205 LastRestoreCheckpoint::Pending { message_id } => *message_id,
206 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
207 }
208 }
209}
210
211#[derive(Clone, Debug, Default, Serialize, Deserialize)]
212pub enum DetailedSummaryState {
213 #[default]
214 NotGenerated,
215 Generating {
216 message_id: MessageId,
217 },
218 Generated {
219 text: SharedString,
220 message_id: MessageId,
221 },
222}
223
224#[derive(Default)]
225pub struct TotalTokenUsage {
226 pub total: usize,
227 pub max: usize,
228 pub ratio: TokenUsageRatio,
229}
230
231#[derive(Default, PartialEq, Eq)]
232pub enum TokenUsageRatio {
233 #[default]
234 Normal,
235 Warning,
236 Exceeded,
237}
238
239/// A thread of conversation with the LLM.
240pub struct Thread {
241 id: ThreadId,
242 updated_at: DateTime<Utc>,
243 summary: Option<SharedString>,
244 pending_summary: Task<Option<()>>,
245 detailed_summary_state: DetailedSummaryState,
246 messages: Vec<Message>,
247 next_message_id: MessageId,
248 context: BTreeMap<ContextId, AssistantContext>,
249 context_by_message: HashMap<MessageId, Vec<ContextId>>,
250 system_prompt_context: Option<AssistantSystemPromptContext>,
251 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
252 completion_count: usize,
253 pending_completions: Vec<PendingCompletion>,
254 project: Entity<Project>,
255 prompt_builder: Arc<PromptBuilder>,
256 tools: Arc<ToolWorkingSet>,
257 tool_use: ToolUseState,
258 action_log: Entity<ActionLog>,
259 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
260 pending_checkpoint: Option<ThreadCheckpoint>,
261 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
262 cumulative_token_usage: TokenUsage,
263 feedback: Option<ThreadFeedback>,
264 message_feedback: HashMap<MessageId, ThreadFeedback>,
265}
266
267impl Thread {
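    /// Creates a new, empty thread for the given project.
    ///
    /// A minimal construction sketch, assuming the caller already holds a
    /// `Project`, a `ToolWorkingSet`, and a `PromptBuilder` (the variable names are
    /// illustrative):
    ///
    /// ```ignore
    /// let thread = cx.new(|cx| {
    ///     Thread::new(project.clone(), tools.clone(), prompt_builder.clone(), cx)
    /// });
    /// ```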
268 pub fn new(
269 project: Entity<Project>,
270 tools: Arc<ToolWorkingSet>,
271 prompt_builder: Arc<PromptBuilder>,
272 cx: &mut Context<Self>,
273 ) -> Self {
274 Self {
275 id: ThreadId::new(),
276 updated_at: Utc::now(),
277 summary: None,
278 pending_summary: Task::ready(None),
279 detailed_summary_state: DetailedSummaryState::NotGenerated,
280 messages: Vec::new(),
281 next_message_id: MessageId(0),
282 context: BTreeMap::default(),
283 context_by_message: HashMap::default(),
284 system_prompt_context: None,
285 checkpoints_by_message: HashMap::default(),
286 completion_count: 0,
287 pending_completions: Vec::new(),
288 project: project.clone(),
289 prompt_builder,
290 tools: tools.clone(),
291 last_restore_checkpoint: None,
292 pending_checkpoint: None,
293 tool_use: ToolUseState::new(tools.clone()),
294 action_log: cx.new(|_| ActionLog::new(project.clone())),
295 initial_project_snapshot: {
296 let project_snapshot = Self::project_snapshot(project, cx);
297 cx.foreground_executor()
298 .spawn(async move { Some(project_snapshot.await) })
299 .shared()
300 },
301 cumulative_token_usage: TokenUsage::default(),
302 feedback: None,
303 message_feedback: HashMap::default(),
304 }
305 }
306
307 pub fn deserialize(
308 id: ThreadId,
309 serialized: SerializedThread,
310 project: Entity<Project>,
311 tools: Arc<ToolWorkingSet>,
312 prompt_builder: Arc<PromptBuilder>,
313 cx: &mut Context<Self>,
314 ) -> Self {
315 let next_message_id = MessageId(
316 serialized
317 .messages
318 .last()
319 .map(|message| message.id.0 + 1)
320 .unwrap_or(0),
321 );
322 let tool_use =
323 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
324
325 Self {
326 id,
327 updated_at: serialized.updated_at,
328 summary: Some(serialized.summary),
329 pending_summary: Task::ready(None),
330 detailed_summary_state: serialized.detailed_summary_state,
331 messages: serialized
332 .messages
333 .into_iter()
334 .map(|message| Message {
335 id: message.id,
336 role: message.role,
337 segments: message
338 .segments
339 .into_iter()
340 .map(|segment| match segment {
341 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
342 SerializedMessageSegment::Thinking { text } => {
343 MessageSegment::Thinking(text)
344 }
345 })
346 .collect(),
347 context: message.context,
348 })
349 .collect(),
350 next_message_id,
351 context: BTreeMap::default(),
352 context_by_message: HashMap::default(),
353 system_prompt_context: None,
354 checkpoints_by_message: HashMap::default(),
355 completion_count: 0,
356 pending_completions: Vec::new(),
357 last_restore_checkpoint: None,
358 pending_checkpoint: None,
359 project: project.clone(),
360 prompt_builder,
361 tools,
362 tool_use,
363 action_log: cx.new(|_| ActionLog::new(project)),
364 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
365 cumulative_token_usage: serialized.cumulative_token_usage,
366 feedback: None,
367 message_feedback: HashMap::default(),
368 }
369 }
370
371 pub fn id(&self) -> &ThreadId {
372 &self.id
373 }
374
375 pub fn is_empty(&self) -> bool {
376 self.messages.is_empty()
377 }
378
379 pub fn updated_at(&self) -> DateTime<Utc> {
380 self.updated_at
381 }
382
383 pub fn touch_updated_at(&mut self) {
384 self.updated_at = Utc::now();
385 }
386
387 pub fn summary(&self) -> Option<SharedString> {
388 self.summary.clone()
389 }
390
391 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
392
393 pub fn summary_or_default(&self) -> SharedString {
394 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
395 }
396
397 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
398 let Some(current_summary) = &self.summary else {
399 // Don't allow setting summary until generated
400 return;
401 };
402
403 let mut new_summary = new_summary.into();
404
405 if new_summary.is_empty() {
406 new_summary = Self::DEFAULT_SUMMARY;
407 }
408
409 if current_summary != &new_summary {
410 self.summary = Some(new_summary);
411 cx.emit(ThreadEvent::SummaryChanged);
412 }
413 }
414
415 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
416 self.latest_detailed_summary()
417 .unwrap_or_else(|| self.text().into())
418 }
419
420 fn latest_detailed_summary(&self) -> Option<SharedString> {
421 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
422 Some(text.clone())
423 } else {
424 None
425 }
426 }
427
428 pub fn message(&self, id: MessageId) -> Option<&Message> {
429 self.messages.iter().find(|message| message.id == id)
430 }
431
432 pub fn messages(&self) -> impl Iterator<Item = &Message> {
433 self.messages.iter()
434 }
435
436 pub fn is_generating(&self) -> bool {
437 !self.pending_completions.is_empty() || !self.all_tools_finished()
438 }
439
440 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
441 &self.tools
442 }
443
444 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
445 self.tool_use
446 .pending_tool_uses()
447 .into_iter()
448 .find(|tool_use| &tool_use.id == id)
449 }
450
451 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
452 self.tool_use
453 .pending_tool_uses()
454 .into_iter()
455 .filter(|tool_use| tool_use.status.needs_confirmation())
456 }
457
458 pub fn has_pending_tool_uses(&self) -> bool {
459 !self.tool_use.pending_tool_uses().is_empty()
460 }
461
462 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
463 self.checkpoints_by_message.get(&id).cloned()
464 }
465
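    /// Restores the project to the state recorded in the given checkpoint and, on
    /// success, truncates the thread back to the checkpoint's message.
    ///
    /// A usage sketch, assuming `message_id` refers to a message that has an
    /// associated checkpoint:
    ///
    /// ```ignore
    /// if let Some(checkpoint) = thread.read(cx).checkpoint_for_message(message_id) {
    ///     let restore = thread.update(cx, |thread, cx| thread.restore_checkpoint(checkpoint, cx));
    ///     // Await or detach `restore`; it resolves once the restore has completed.
    /// }
    /// ```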
466 pub fn restore_checkpoint(
467 &mut self,
468 checkpoint: ThreadCheckpoint,
469 cx: &mut Context<Self>,
470 ) -> Task<Result<()>> {
471 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
472 message_id: checkpoint.message_id,
473 });
474 cx.emit(ThreadEvent::CheckpointChanged);
475 cx.notify();
476
477 let git_store = self.project().read(cx).git_store().clone();
478 let restore = git_store.update(cx, |git_store, cx| {
479 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
480 });
481
482 cx.spawn(async move |this, cx| {
483 let result = restore.await;
484 this.update(cx, |this, cx| {
485 if let Err(err) = result.as_ref() {
486 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
487 message_id: checkpoint.message_id,
488 error: err.to_string(),
489 });
490 } else {
491 this.truncate(checkpoint.message_id, cx);
492 this.last_restore_checkpoint = None;
493 }
494 this.pending_checkpoint = None;
495 cx.emit(ThreadEvent::CheckpointChanged);
496 cx.notify();
497 })?;
498 result
499 })
500 }
501
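    /// Resolves the checkpoint captured when the last user message was sent: once
    /// generation and tool use have finished, the pending checkpoint is kept only if
    /// the repository actually changed; otherwise it is discarded.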
502 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
503 let pending_checkpoint = if self.is_generating() {
504 return;
505 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
506 checkpoint
507 } else {
508 return;
509 };
510
511 let git_store = self.project.read(cx).git_store().clone();
512 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
513 cx.spawn(async move |this, cx| match final_checkpoint.await {
514 Ok(final_checkpoint) => {
515 let equal = git_store
516 .update(cx, |store, cx| {
517 store.compare_checkpoints(
518 pending_checkpoint.git_checkpoint.clone(),
519 final_checkpoint.clone(),
520 cx,
521 )
522 })?
523 .await
524 .unwrap_or(false);
525
526 if equal {
527 git_store
528 .update(cx, |store, cx| {
529 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
530 })?
531 .detach();
532 } else {
533 this.update(cx, |this, cx| {
534 this.insert_checkpoint(pending_checkpoint, cx)
535 })?;
536 }
537
538 git_store
539 .update(cx, |store, cx| {
540 store.delete_checkpoint(final_checkpoint, cx)
541 })?
542 .detach();
543
544 Ok(())
545 }
546 Err(_) => this.update(cx, |this, cx| {
547 this.insert_checkpoint(pending_checkpoint, cx)
548 }),
549 })
550 .detach();
551 }
552
553 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
554 self.checkpoints_by_message
555 .insert(checkpoint.message_id, checkpoint);
556 cx.emit(ThreadEvent::CheckpointChanged);
557 cx.notify();
558 }
559
560 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
561 self.last_restore_checkpoint.as_ref()
562 }
563
564 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
565 let Some(message_ix) = self
566 .messages
567 .iter()
568 .rposition(|message| message.id == message_id)
569 else {
570 return;
571 };
572 for deleted_message in self.messages.drain(message_ix..) {
573 self.context_by_message.remove(&deleted_message.id);
574 self.checkpoints_by_message.remove(&deleted_message.id);
575 }
576 cx.notify();
577 }
578
579 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
580 self.context_by_message
581 .get(&id)
582 .into_iter()
583 .flat_map(|context| {
584 context
585 .iter()
                    .filter_map(|context_id| self.context.get(context_id))
587 })
588 }
589
590 /// Returns whether all of the tool uses have finished running.
591 pub fn all_tools_finished(&self) -> bool {
592 // If the only pending tool uses left are the ones with errors, then
593 // that means that we've finished running all of the pending tools.
594 self.tool_use
595 .pending_tool_uses()
596 .iter()
597 .all(|tool_use| tool_use.status.is_error())
598 }
599
600 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
601 self.tool_use.tool_uses_for_message(id, cx)
602 }
603
604 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
605 self.tool_use.tool_results_for_message(id)
606 }
607
608 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
609 self.tool_use.tool_result(id)
610 }
611
612 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
613 self.tool_use.message_has_tool_results(message_id)
614 }
615
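    /// Inserts a user message along with any newly attached context, and records a
    /// pending checkpoint when a git checkpoint is provided.
    ///
    /// A sketch mirroring how the tests drive this API (the context vector may be
    /// empty):
    ///
    /// ```ignore
    /// let message_id = thread.update(cx, |thread, cx| {
    ///     thread.insert_user_message("Please explain this code", vec![context], None, cx)
    /// });
    /// ```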
616 pub fn insert_user_message(
617 &mut self,
618 text: impl Into<String>,
619 context: Vec<AssistantContext>,
620 git_checkpoint: Option<GitStoreCheckpoint>,
621 cx: &mut Context<Self>,
622 ) -> MessageId {
623 let text = text.into();
624
625 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
626
627 // Filter out contexts that have already been included in previous messages
628 let new_context: Vec<_> = context
629 .into_iter()
630 .filter(|ctx| !self.context.contains_key(&ctx.id()))
631 .collect();
632
633 if !new_context.is_empty() {
634 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
635 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
636 message.context = context_string;
637 }
638 }
639
640 self.action_log.update(cx, |log, cx| {
641 // Track all buffers added as context
642 for ctx in &new_context {
643 match ctx {
644 AssistantContext::File(file_ctx) => {
645 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
646 }
647 AssistantContext::Directory(dir_ctx) => {
648 for context_buffer in &dir_ctx.context_buffers {
649 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
650 }
651 }
652 AssistantContext::Symbol(symbol_ctx) => {
653 log.buffer_added_as_context(
654 symbol_ctx.context_symbol.buffer.clone(),
655 cx,
656 );
657 }
658 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
659 }
660 }
661 });
662 }
663
664 let context_ids = new_context
665 .iter()
666 .map(|context| context.id())
667 .collect::<Vec<_>>();
668 self.context.extend(
669 new_context
670 .into_iter()
671 .map(|context| (context.id(), context)),
672 );
673 self.context_by_message.insert(message_id, context_ids);
674
675 if let Some(git_checkpoint) = git_checkpoint {
676 self.pending_checkpoint = Some(ThreadCheckpoint {
677 message_id,
678 git_checkpoint,
679 });
680 }
681
682 self.auto_capture_telemetry(cx);
683
684 message_id
685 }
686
687 pub fn insert_message(
688 &mut self,
689 role: Role,
690 segments: Vec<MessageSegment>,
691 cx: &mut Context<Self>,
692 ) -> MessageId {
693 let id = self.next_message_id.post_inc();
694 self.messages.push(Message {
695 id,
696 role,
697 segments,
698 context: String::new(),
699 });
700 self.touch_updated_at();
701 cx.emit(ThreadEvent::MessageAdded(id));
702 id
703 }
704
705 pub fn edit_message(
706 &mut self,
707 id: MessageId,
708 new_role: Role,
709 new_segments: Vec<MessageSegment>,
710 cx: &mut Context<Self>,
711 ) -> bool {
712 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
713 return false;
714 };
715 message.role = new_role;
716 message.segments = new_segments;
717 self.touch_updated_at();
718 cx.emit(ThreadEvent::MessageEdited(id));
719 true
720 }
721
722 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
723 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
724 return false;
725 };
726 self.messages.remove(index);
727 self.context_by_message.remove(&id);
728 self.touch_updated_at();
729 cx.emit(ThreadEvent::MessageDeleted(id));
730 true
731 }
732
733 /// Returns the representation of this [`Thread`] in a textual form.
734 ///
735 /// This is the representation we use when attaching a thread as context to another thread.
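    ///
    /// Each message is prefixed with its role (`User:`, `Assistant:`, or `System:`),
    /// and thinking segments are wrapped in `<think>` tags. For example
    /// (illustrative content):
    ///
    /// ```text
    /// User:
    /// What does this function do?
    /// Assistant:
    /// <think>Skim the signature first.</think>It formats the response.
    /// ```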
736 pub fn text(&self) -> String {
737 let mut text = String::new();
738
739 for message in &self.messages {
740 text.push_str(match message.role {
741 language_model::Role::User => "User:",
742 language_model::Role::Assistant => "Assistant:",
743 language_model::Role::System => "System:",
744 });
745 text.push('\n');
746
747 for segment in &message.segments {
748 match segment {
749 MessageSegment::Text(content) => text.push_str(content),
750 MessageSegment::Thinking(content) => {
751 text.push_str(&format!("<think>{}</think>", content))
752 }
753 }
754 }
755 text.push('\n');
756 }
757
758 text
759 }
760
761 /// Serializes this thread into a format for storage or telemetry.
762 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
763 let initial_project_snapshot = self.initial_project_snapshot.clone();
764 cx.spawn(async move |this, cx| {
765 let initial_project_snapshot = initial_project_snapshot.await;
766 this.read_with(cx, |this, cx| SerializedThread {
767 version: SerializedThread::VERSION.to_string(),
768 summary: this.summary_or_default(),
769 updated_at: this.updated_at(),
770 messages: this
771 .messages()
772 .map(|message| SerializedMessage {
773 id: message.id,
774 role: message.role,
775 segments: message
776 .segments
777 .iter()
778 .map(|segment| match segment {
779 MessageSegment::Text(text) => {
780 SerializedMessageSegment::Text { text: text.clone() }
781 }
782 MessageSegment::Thinking(text) => {
783 SerializedMessageSegment::Thinking { text: text.clone() }
784 }
785 })
786 .collect(),
787 tool_uses: this
788 .tool_uses_for_message(message.id, cx)
789 .into_iter()
790 .map(|tool_use| SerializedToolUse {
791 id: tool_use.id,
792 name: tool_use.name,
793 input: tool_use.input,
794 })
795 .collect(),
796 tool_results: this
797 .tool_results_for_message(message.id)
798 .into_iter()
799 .map(|tool_result| SerializedToolResult {
800 tool_use_id: tool_result.tool_use_id.clone(),
801 is_error: tool_result.is_error,
802 content: tool_result.content.clone(),
803 })
804 .collect(),
805 context: message.context.clone(),
806 })
807 .collect(),
808 initial_project_snapshot,
809 cumulative_token_usage: this.cumulative_token_usage.clone(),
810 detailed_summary_state: this.detailed_summary_state.clone(),
811 })
812 })
813 }
814
815 pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
816 self.system_prompt_context = Some(context);
817 }
818
819 pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
820 &self.system_prompt_context
821 }
822
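    /// Loads per-worktree rules files and assembles the system prompt context.
    ///
    /// A wiring sketch, assuming the result is stored on the thread before the next
    /// request is sent (error handling elided; the returned `ThreadError`, if any,
    /// is meant to be shown to the user):
    ///
    /// ```ignore
    /// let load = thread.read(cx).load_system_prompt_context(cx);
    /// cx.spawn(async move |cx| {
    ///     let (system_prompt_context, _error) = load.await;
    ///     thread
    ///         .update(cx, |thread, _cx| thread.set_system_prompt_context(system_prompt_context))
    ///         .ok();
    /// })
    /// .detach();
    /// ```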
823 pub fn load_system_prompt_context(
824 &self,
825 cx: &App,
826 ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
827 let project = self.project.read(cx);
828 let tasks = project
829 .visible_worktrees(cx)
830 .map(|worktree| {
831 Self::load_worktree_info_for_system_prompt(
832 project.fs().clone(),
833 worktree.read(cx),
834 cx,
835 )
836 })
837 .collect::<Vec<_>>();
838
        cx.spawn(async move |_cx| {
840 let results = futures::future::join_all(tasks).await;
841 let mut first_err = None;
842 let worktrees = results
843 .into_iter()
844 .map(|(worktree, err)| {
845 if first_err.is_none() && err.is_some() {
846 first_err = err;
847 }
848 worktree
849 })
850 .collect::<Vec<_>>();
851 (AssistantSystemPromptContext::new(worktrees), first_err)
852 })
853 }
854
855 fn load_worktree_info_for_system_prompt(
856 fs: Arc<dyn Fs>,
857 worktree: &Worktree,
858 cx: &App,
859 ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
860 let root_name = worktree.root_name().into();
861 let abs_path = worktree.abs_path();
862
863 let rules_task = load_worktree_rules_file(fs, worktree, cx);
864 let Some(rules_task) = rules_task else {
865 return Task::ready((
866 WorktreeInfoForSystemPrompt {
867 root_name,
868 abs_path,
869 rules_file: None,
870 },
871 None,
872 ));
873 };
874
875 cx.spawn(async move |_| {
876 let (rules_file, rules_file_error) = match rules_task.await {
877 Ok(rules_file) => (Some(rules_file), None),
878 Err(err) => (
879 None,
880 Some(ThreadError::Message {
881 header: "Error loading rules file".into(),
882 message: format!("{err}").into(),
883 }),
884 ),
885 };
886 let worktree_info = WorktreeInfoForSystemPrompt {
887 root_name,
888 abs_path,
889 rules_file,
890 };
891 (worktree_info, rules_file_error)
892 })
893 }
894
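    /// Builds a completion request from the current thread state and streams the
    /// model's response back into the thread.
    ///
    /// A sketch, assuming a default model is configured in the
    /// [`LanguageModelRegistry`]:
    ///
    /// ```ignore
    /// if let Some(ConfiguredModel { model, .. }) =
    ///     LanguageModelRegistry::read_global(cx).default_model()
    /// {
    ///     thread.update(cx, |thread, cx| {
    ///         thread.send_to_model(model, RequestKind::Chat, cx);
    ///     });
    /// }
    /// ```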
895 pub fn send_to_model(
896 &mut self,
897 model: Arc<dyn LanguageModel>,
898 request_kind: RequestKind,
899 cx: &mut Context<Self>,
900 ) {
901 let mut request = self.to_completion_request(request_kind, cx);
902 if model.supports_tools() {
903 request.tools = {
904 let mut tools = Vec::new();
905 tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
906 LanguageModelRequestTool {
907 name: tool.name(),
908 description: tool.description(),
909 input_schema: tool.input_schema(model.tool_input_format()),
910 }
911 }));
912
913 tools
914 };
915 }
916
917 self.stream_completion(request, model, cx);
918 }
919
920 pub fn used_tools_since_last_user_message(&self) -> bool {
921 for message in self.messages.iter().rev() {
922 if self.tool_use.message_has_tool_results(message.id) {
923 return true;
924 } else if message.role == Role::User {
925 return false;
926 }
927 }
928
929 false
930 }
931
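    /// Builds the [`LanguageModelRequest`] for the current thread state: the system
    /// prompt (when available), every prior message, and, for [`RequestKind::Chat`],
    /// the accumulated tool uses and tool results.
    ///
    /// A read-only sketch, as used by the tests (`expected_message_count` is
    /// illustrative):
    ///
    /// ```ignore
    /// let request = thread.read_with(cx, |thread, cx| {
    ///     thread.to_completion_request(RequestKind::Chat, cx)
    /// });
    /// assert_eq!(request.messages.len(), expected_message_count);
    /// ```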
932 pub fn to_completion_request(
933 &self,
934 request_kind: RequestKind,
935 cx: &App,
936 ) -> LanguageModelRequest {
937 let mut request = LanguageModelRequest {
938 messages: vec![],
939 tools: Vec::new(),
940 stop: Vec::new(),
941 temperature: None,
942 };
943
944 if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
945 if let Some(system_prompt) = self
946 .prompt_builder
947 .generate_assistant_system_prompt(system_prompt_context)
948 .context("failed to generate assistant system prompt")
949 .log_err()
950 {
951 request.messages.push(LanguageModelRequestMessage {
952 role: Role::System,
953 content: vec![MessageContent::Text(system_prompt)],
954 cache: true,
955 });
956 }
957 } else {
958 log::error!("system_prompt_context not set.")
959 }
960
961 for message in &self.messages {
962 let mut request_message = LanguageModelRequestMessage {
963 role: message.role,
964 content: Vec::new(),
965 cache: false,
966 };
967
968 match request_kind {
969 RequestKind::Chat => {
970 self.tool_use
971 .attach_tool_results(message.id, &mut request_message);
972 }
973 RequestKind::Summarize => {
974 // We don't care about tool use during summarization.
975 if self.tool_use.message_has_tool_results(message.id) {
976 continue;
977 }
978 }
979 }
980
981 if !message.segments.is_empty() {
982 request_message
983 .content
984 .push(MessageContent::Text(message.to_string()));
985 }
986
987 match request_kind {
988 RequestKind::Chat => {
989 self.tool_use
990 .attach_tool_uses(message.id, &mut request_message);
991 }
992 RequestKind::Summarize => {
993 // We don't care about tool use during summarization.
994 }
995 };
996
997 request.messages.push(request_message);
998 }
999
1000 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1001 if let Some(last) = request.messages.last_mut() {
1002 last.cache = true;
1003 }
1004
1005 self.attached_tracked_files_state(&mut request.messages, cx);
1006
1007 request
1008 }
1009
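    /// Appends a trailing user message describing files that changed since the model
    /// last read them, plus a reminder to re-check project diagnostics after edits.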
1010 fn attached_tracked_files_state(
1011 &self,
1012 messages: &mut Vec<LanguageModelRequestMessage>,
1013 cx: &App,
1014 ) {
1015 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1016
1017 let mut stale_message = String::new();
1018
1019 let action_log = self.action_log.read(cx);
1020
1021 for stale_file in action_log.stale_buffers(cx) {
1022 let Some(file) = stale_file.read(cx).file() else {
1023 continue;
1024 };
1025
1026 if stale_message.is_empty() {
                writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
1028 }
1029
1030 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1031 }
1032
1033 let mut content = Vec::with_capacity(2);
1034
1035 if !stale_message.is_empty() {
1036 content.push(stale_message.into());
1037 }
1038
1039 if action_log.has_edited_files_since_project_diagnostics_check() {
1040 content.push(
1041 "\n\nWhen you're done making changes, make sure to check project diagnostics \
1042 and fix all errors AND warnings you introduced! \
1043 DO NOT mention you're going to do this until you're done."
1044 .into(),
1045 );
1046 }
1047
1048 if !content.is_empty() {
1049 let context_message = LanguageModelRequestMessage {
1050 role: Role::User,
1051 content,
1052 cache: false,
1053 };
1054
1055 messages.push(context_message);
1056 }
1057 }
1058
1059 pub fn stream_completion(
1060 &mut self,
1061 request: LanguageModelRequest,
1062 model: Arc<dyn LanguageModel>,
1063 cx: &mut Context<Self>,
1064 ) {
1065 let pending_completion_id = post_inc(&mut self.completion_count);
1066
1067 let task = cx.spawn(async move |thread, cx| {
1068 let stream = model.stream_completion(request, &cx);
1069 let initial_token_usage =
1070 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
1071 let stream_completion = async {
1072 let mut events = stream.await?;
1073 let mut stop_reason = StopReason::EndTurn;
1074 let mut current_token_usage = TokenUsage::default();
1075
1076 while let Some(event) = events.next().await {
1077 let event = event?;
1078
1079 thread.update(cx, |thread, cx| {
1080 match event {
1081 LanguageModelCompletionEvent::StartMessage { .. } => {
1082 thread.insert_message(
1083 Role::Assistant,
1084 vec![MessageSegment::Text(String::new())],
1085 cx,
1086 );
1087 }
1088 LanguageModelCompletionEvent::Stop(reason) => {
1089 stop_reason = reason;
1090 }
1091 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1092 thread.cumulative_token_usage =
1093 thread.cumulative_token_usage.clone() + token_usage.clone()
1094 - current_token_usage.clone();
1095 current_token_usage = token_usage;
1096 }
1097 LanguageModelCompletionEvent::Text(chunk) => {
1098 if let Some(last_message) = thread.messages.last_mut() {
1099 if last_message.role == Role::Assistant {
1100 last_message.push_text(&chunk);
1101 cx.emit(ThreadEvent::StreamedAssistantText(
1102 last_message.id,
1103 chunk,
1104 ));
1105 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1107 // of a new Assistant response.
1108 //
1109 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1110 // will result in duplicating the text of the chunk in the rendered Markdown.
1111 thread.insert_message(
1112 Role::Assistant,
1113 vec![MessageSegment::Text(chunk.to_string())],
1114 cx,
1115 );
1116 };
1117 }
1118 }
1119 LanguageModelCompletionEvent::Thinking(chunk) => {
1120 if let Some(last_message) = thread.messages.last_mut() {
1121 if last_message.role == Role::Assistant {
1122 last_message.push_thinking(&chunk);
1123 cx.emit(ThreadEvent::StreamedAssistantThinking(
1124 last_message.id,
1125 chunk,
1126 ));
1127 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
1132 // will result in duplicating the text of the chunk in the rendered Markdown.
1133 thread.insert_message(
1134 Role::Assistant,
1135 vec![MessageSegment::Thinking(chunk.to_string())],
1136 cx,
1137 );
1138 };
1139 }
1140 }
1141 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1142 let last_assistant_message_id = thread
1143 .messages
1144 .iter_mut()
1145 .rfind(|message| message.role == Role::Assistant)
1146 .map(|message| message.id)
1147 .unwrap_or_else(|| {
1148 thread.insert_message(Role::Assistant, vec![], cx)
1149 });
1150
1151 thread.tool_use.request_tool_use(
1152 last_assistant_message_id,
1153 tool_use,
1154 cx,
1155 );
1156 }
1157 }
1158
1159 thread.touch_updated_at();
1160 cx.emit(ThreadEvent::StreamedCompletion);
1161 cx.notify();
1162
1163 thread.auto_capture_telemetry(cx);
1164 })?;
1165
1166 smol::future::yield_now().await;
1167 }
1168
1169 thread.update(cx, |thread, cx| {
1170 thread
1171 .pending_completions
1172 .retain(|completion| completion.id != pending_completion_id);
1173
1174 if thread.summary.is_none() && thread.messages.len() >= 2 {
1175 thread.summarize(cx);
1176 }
1177 })?;
1178
1179 anyhow::Ok(stop_reason)
1180 };
1181
1182 let result = stream_completion.await;
1183
1184 thread
1185 .update(cx, |thread, cx| {
1186 thread.finalize_pending_checkpoint(cx);
1187 match result.as_ref() {
1188 Ok(stop_reason) => match stop_reason {
1189 StopReason::ToolUse => {
1190 let tool_uses = thread.use_pending_tools(cx);
1191 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1192 }
1193 StopReason::EndTurn => {}
1194 StopReason::MaxTokens => {}
1195 },
1196 Err(error) => {
1197 if error.is::<PaymentRequiredError>() {
1198 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1199 } else if error.is::<MaxMonthlySpendReachedError>() {
1200 cx.emit(ThreadEvent::ShowError(
1201 ThreadError::MaxMonthlySpendReached,
1202 ));
1203 } else {
1204 let error_message = error
1205 .chain()
1206 .map(|err| err.to_string())
1207 .collect::<Vec<_>>()
1208 .join("\n");
1209 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1210 header: "Error interacting with language model".into(),
1211 message: SharedString::from(error_message.clone()),
1212 }));
1213 }
1214
1215 thread.cancel_last_completion(cx);
1216 }
1217 }
1218 cx.emit(ThreadEvent::DoneStreaming);
1219
1220 thread.auto_capture_telemetry(cx);
1221
1222 if let Ok(initial_usage) = initial_token_usage {
1223 let usage = thread.cumulative_token_usage.clone() - initial_usage;
1224
1225 telemetry::event!(
1226 "Assistant Thread Completion",
1227 thread_id = thread.id().to_string(),
1228 model = model.telemetry_id(),
1229 model_provider = model.provider_id().to_string(),
1230 input_tokens = usage.input_tokens,
1231 output_tokens = usage.output_tokens,
1232 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1233 cache_read_input_tokens = usage.cache_read_input_tokens,
1234 );
1235 }
1236 })
1237 .ok();
1238 });
1239
1240 self.pending_completions.push(PendingCompletion {
1241 id: pending_completion_id,
1242 _task: task,
1243 });
1244 }
1245
1246 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1247 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1248 return;
1249 };
1250
1251 if !model.provider.is_authenticated(cx) {
1252 return;
1253 }
1254
1255 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1256 request.messages.push(LanguageModelRequestMessage {
1257 role: Role::User,
1258 content: vec![
1259 "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
                Go straight to the title, without any preamble or prefix like `Here's a concise suggestion:...` or `Title:`. \
1261 If the conversation is about a specific subject, include it in the title. \
1262 Be descriptive. DO NOT speak in the first person."
1263 .into(),
1264 ],
1265 cache: false,
1266 });
1267
1268 self.pending_summary = cx.spawn(async move |this, cx| {
1269 async move {
1270 let stream = model.model.stream_completion_text(request, &cx);
1271 let mut messages = stream.await?;
1272
1273 let mut new_summary = String::new();
1274 while let Some(message) = messages.stream.next().await {
1275 let text = message?;
1276 let mut lines = text.lines();
1277 new_summary.extend(lines.next());
1278
1279 // Stop if the LLM generated multiple lines.
1280 if lines.next().is_some() {
1281 break;
1282 }
1283 }
1284
1285 this.update(cx, |this, cx| {
1286 if !new_summary.is_empty() {
1287 this.summary = Some(new_summary.into());
1288 }
1289
1290 cx.emit(ThreadEvent::SummaryGenerated);
1291 })?;
1292
1293 anyhow::Ok(())
1294 }
1295 .log_err()
1296 .await
1297 });
1298 }
1299
1300 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1301 let last_message_id = self.messages.last().map(|message| message.id)?;
1302
1303 match &self.detailed_summary_state {
1304 DetailedSummaryState::Generating { message_id, .. }
1305 | DetailedSummaryState::Generated { message_id, .. }
1306 if *message_id == last_message_id =>
1307 {
1308 // Already up-to-date
1309 return None;
1310 }
1311 _ => {}
1312 }
1313
1314 let ConfiguredModel { model, provider } =
1315 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1316
1317 if !provider.is_authenticated(cx) {
1318 return None;
1319 }
1320
1321 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1322
1323 request.messages.push(LanguageModelRequestMessage {
1324 role: Role::User,
1325 content: vec![
1326 "Generate a detailed summary of this conversation. Include:\n\
1327 1. A brief overview of what was discussed\n\
1328 2. Key facts or information discovered\n\
1329 3. Outcomes or conclusions reached\n\
1330 4. Any action items or next steps if any\n\
1331 Format it in Markdown with headings and bullet points."
1332 .into(),
1333 ],
1334 cache: false,
1335 });
1336
1337 let task = cx.spawn(async move |thread, cx| {
1338 let stream = model.stream_completion_text(request, &cx);
1339 let Some(mut messages) = stream.await.log_err() else {
1340 thread
1341 .update(cx, |this, _cx| {
1342 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1343 })
1344 .log_err();
1345
1346 return;
1347 };
1348
1349 let mut new_detailed_summary = String::new();
1350
1351 while let Some(chunk) = messages.stream.next().await {
1352 if let Some(chunk) = chunk.log_err() {
1353 new_detailed_summary.push_str(&chunk);
1354 }
1355 }
1356
1357 thread
1358 .update(cx, |this, _cx| {
1359 this.detailed_summary_state = DetailedSummaryState::Generated {
1360 text: new_detailed_summary.into(),
1361 message_id: last_message_id,
1362 };
1363 })
1364 .log_err();
1365 });
1366
1367 self.detailed_summary_state = DetailedSummaryState::Generating {
1368 message_id: last_message_id,
1369 };
1370
1371 Some(task)
1372 }
1373
1374 pub fn is_generating_detailed_summary(&self) -> bool {
1375 matches!(
1376 self.detailed_summary_state,
1377 DetailedSummaryState::Generating { .. }
1378 )
1379 }
1380
1381 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1382 self.auto_capture_telemetry(cx);
1383 let request = self.to_completion_request(RequestKind::Chat, cx);
1384 let messages = Arc::new(request.messages);
1385 let pending_tool_uses = self
1386 .tool_use
1387 .pending_tool_uses()
1388 .into_iter()
1389 .filter(|tool_use| tool_use.status.is_idle())
1390 .cloned()
1391 .collect::<Vec<_>>();
1392
1393 for tool_use in pending_tool_uses.iter() {
1394 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1395 if tool.needs_confirmation(&tool_use.input, cx)
1396 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1397 {
1398 self.tool_use.confirm_tool_use(
1399 tool_use.id.clone(),
1400 tool_use.ui_text.clone(),
1401 tool_use.input.clone(),
1402 messages.clone(),
1403 tool,
1404 );
1405 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1406 } else {
1407 self.run_tool(
1408 tool_use.id.clone(),
1409 tool_use.ui_text.clone(),
1410 tool_use.input.clone(),
1411 &messages,
1412 tool,
1413 cx,
1414 );
1415 }
1416 }
1417 }
1418
1419 pending_tool_uses
1420 }
1421
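    /// Runs a tool use, typically after the user has confirmed it.
    ///
    /// A hedged sketch of approving a pending tool use from the UI layer, assuming
    /// `pending` is a [`PendingToolUse`] surfaced via
    /// [`ThreadEvent::ToolConfirmationNeeded`], and `messages`/`tool` were captured
    /// alongside it (all three names are illustrative):
    ///
    /// ```ignore
    /// thread.update(cx, |thread, cx| {
    ///     thread.run_tool(
    ///         pending.id.clone(),
    ///         pending.ui_text.clone(),
    ///         pending.input.clone(),
    ///         &messages,
    ///         tool,
    ///         cx,
    ///     );
    /// });
    /// ```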
1422 pub fn run_tool(
1423 &mut self,
1424 tool_use_id: LanguageModelToolUseId,
1425 ui_text: impl Into<SharedString>,
1426 input: serde_json::Value,
1427 messages: &[LanguageModelRequestMessage],
1428 tool: Arc<dyn Tool>,
1429 cx: &mut Context<Thread>,
1430 ) {
1431 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1432 self.tool_use
1433 .run_pending_tool(tool_use_id, ui_text.into(), task);
1434 }
1435
1436 fn spawn_tool_use(
1437 &mut self,
1438 tool_use_id: LanguageModelToolUseId,
1439 messages: &[LanguageModelRequestMessage],
1440 input: serde_json::Value,
1441 tool: Arc<dyn Tool>,
1442 cx: &mut Context<Thread>,
1443 ) -> Task<()> {
1444 let tool_name: Arc<str> = tool.name().into();
1445
1446 let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
1447 Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
1448 } else {
1449 tool.run(
1450 input,
1451 messages,
1452 self.project.clone(),
1453 self.action_log.clone(),
1454 cx,
1455 )
1456 };
1457
1458 cx.spawn({
1459 async move |thread: WeakEntity<Thread>, cx| {
1460 let output = run_tool.await;
1461
1462 thread
1463 .update(cx, |thread, cx| {
1464 let pending_tool_use = thread.tool_use.insert_tool_output(
1465 tool_use_id.clone(),
1466 tool_name,
1467 output,
1468 cx,
1469 );
1470 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1471 })
1472 .ok();
1473 }
1474 })
1475 }
1476
1477 fn tool_finished(
1478 &mut self,
1479 tool_use_id: LanguageModelToolUseId,
1480 pending_tool_use: Option<PendingToolUse>,
1481 canceled: bool,
1482 cx: &mut Context<Self>,
1483 ) {
1484 if self.all_tools_finished() {
1485 let model_registry = LanguageModelRegistry::read_global(cx);
1486 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1487 self.attach_tool_results(cx);
1488 if !canceled {
1489 self.send_to_model(model, RequestKind::Chat, cx);
1490 }
1491 }
1492 }
1493
1494 cx.emit(ThreadEvent::ToolFinished {
1495 tool_use_id,
1496 pending_tool_use,
1497 });
1498 }
1499
1500 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1501 // Insert a user message to contain the tool results.
1502 self.insert_user_message(
1503 // TODO: Sending up a user message without any content results in the model sending back
1504 // responses that also don't have any content. We currently don't handle this case well,
1505 // so for now we provide some text to keep the model on track.
1506 "Here are the tool results.",
1507 Vec::new(),
1508 None,
1509 cx,
1510 );
1511 }
1512
1513 /// Cancels the last pending completion, if there are any pending.
1514 ///
1515 /// Returns whether a completion was canceled.
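    ///
    /// A sketch of a typical call site (for example, a "stop generating" action):
    ///
    /// ```ignore
    /// let canceled = thread.update(cx, |thread, cx| thread.cancel_last_completion(cx));
    /// ```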
1516 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1517 let canceled = if self.pending_completions.pop().is_some() {
1518 true
1519 } else {
1520 let mut canceled = false;
1521 for pending_tool_use in self.tool_use.cancel_pending() {
1522 canceled = true;
1523 self.tool_finished(
1524 pending_tool_use.id.clone(),
1525 Some(pending_tool_use),
1526 true,
1527 cx,
1528 );
1529 }
1530 canceled
1531 };
1532 self.finalize_pending_checkpoint(cx);
1533 canceled
1534 }
1535
1536 pub fn feedback(&self) -> Option<ThreadFeedback> {
1537 self.feedback
1538 }
1539
1540 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1541 self.message_feedback.get(&message_id).copied()
1542 }
1543
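    /// Records thumbs-up/down feedback for a single message and reports it via
    /// telemetry, along with a snapshot of the thread and project.
    ///
    /// A sketch, assuming the caller detaches (or awaits) the returned task:
    ///
    /// ```ignore
    /// thread
    ///     .update(cx, |thread, cx| {
    ///         thread.report_message_feedback(message_id, ThreadFeedback::Positive, cx)
    ///     })
    ///     .detach_and_log_err(cx);
    /// ```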
1544 pub fn report_message_feedback(
1545 &mut self,
1546 message_id: MessageId,
1547 feedback: ThreadFeedback,
1548 cx: &mut Context<Self>,
1549 ) -> Task<Result<()>> {
1550 if self.message_feedback.get(&message_id) == Some(&feedback) {
1551 return Task::ready(Ok(()));
1552 }
1553
1554 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1555 let serialized_thread = self.serialize(cx);
1556 let thread_id = self.id().clone();
1557 let client = self.project.read(cx).client();
1558
1559 let enabled_tool_names: Vec<String> = self
1560 .tools()
1561 .enabled_tools(cx)
1562 .iter()
1563 .map(|tool| tool.name().to_string())
1564 .collect();
1565
1566 self.message_feedback.insert(message_id, feedback);
1567
1568 cx.notify();
1569
1570 let message_content = self
1571 .message(message_id)
1572 .map(|msg| msg.to_string())
1573 .unwrap_or_default();
1574
1575 cx.background_spawn(async move {
1576 let final_project_snapshot = final_project_snapshot.await;
1577 let serialized_thread = serialized_thread.await?;
1578 let thread_data =
1579 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1580
1581 let rating = match feedback {
1582 ThreadFeedback::Positive => "positive",
1583 ThreadFeedback::Negative => "negative",
1584 };
1585 telemetry::event!(
1586 "Assistant Thread Rated",
1587 rating,
1588 thread_id,
1589 enabled_tool_names,
1590 message_id = message_id.0,
1591 message_content,
1592 thread_data,
1593 final_project_snapshot
1594 );
1595 client.telemetry().flush_events();
1596
1597 Ok(())
1598 })
1599 }
1600
1601 pub fn report_feedback(
1602 &mut self,
1603 feedback: ThreadFeedback,
1604 cx: &mut Context<Self>,
1605 ) -> Task<Result<()>> {
1606 let last_assistant_message_id = self
1607 .messages
1608 .iter()
1609 .rev()
1610 .find(|msg| msg.role == Role::Assistant)
1611 .map(|msg| msg.id);
1612
1613 if let Some(message_id) = last_assistant_message_id {
1614 self.report_message_feedback(message_id, feedback, cx)
1615 } else {
1616 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1617 let serialized_thread = self.serialize(cx);
1618 let thread_id = self.id().clone();
1619 let client = self.project.read(cx).client();
1620 self.feedback = Some(feedback);
1621 cx.notify();
1622
1623 cx.background_spawn(async move {
1624 let final_project_snapshot = final_project_snapshot.await;
1625 let serialized_thread = serialized_thread.await?;
1626 let thread_data = serde_json::to_value(serialized_thread)
1627 .unwrap_or_else(|_| serde_json::Value::Null);
1628
1629 let rating = match feedback {
1630 ThreadFeedback::Positive => "positive",
1631 ThreadFeedback::Negative => "negative",
1632 };
1633 telemetry::event!(
1634 "Assistant Thread Rated",
1635 rating,
1636 thread_id,
1637 thread_data,
1638 final_project_snapshot
1639 );
1640 client.telemetry().flush_events();
1641
1642 Ok(())
1643 })
1644 }
1645 }
1646
    /// Creates a snapshot of the current project state, including git information and unsaved buffers.
1648 fn project_snapshot(
1649 project: Entity<Project>,
1650 cx: &mut Context<Self>,
1651 ) -> Task<Arc<ProjectSnapshot>> {
1652 let git_store = project.read(cx).git_store().clone();
1653 let worktree_snapshots: Vec<_> = project
1654 .read(cx)
1655 .visible_worktrees(cx)
1656 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1657 .collect();
1658
1659 cx.spawn(async move |_, cx| {
1660 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1661
1662 let mut unsaved_buffers = Vec::new();
1663 cx.update(|app_cx| {
1664 let buffer_store = project.read(app_cx).buffer_store();
1665 for buffer_handle in buffer_store.read(app_cx).buffers() {
1666 let buffer = buffer_handle.read(app_cx);
1667 if buffer.is_dirty() {
1668 if let Some(file) = buffer.file() {
1669 let path = file.path().to_string_lossy().to_string();
1670 unsaved_buffers.push(path);
1671 }
1672 }
1673 }
1674 })
1675 .ok();
1676
1677 Arc::new(ProjectSnapshot {
1678 worktree_snapshots,
1679 unsaved_buffer_paths: unsaved_buffers,
1680 timestamp: Utc::now(),
1681 })
1682 })
1683 }
1684
1685 fn worktree_snapshot(
1686 worktree: Entity<project::Worktree>,
1687 git_store: Entity<GitStore>,
1688 cx: &App,
1689 ) -> Task<WorktreeSnapshot> {
1690 cx.spawn(async move |cx| {
1691 // Get worktree path and snapshot
1692 let worktree_info = cx.update(|app_cx| {
1693 let worktree = worktree.read(app_cx);
1694 let path = worktree.abs_path().to_string_lossy().to_string();
1695 let snapshot = worktree.snapshot();
1696 (path, snapshot)
1697 });
1698
1699 let Ok((worktree_path, _snapshot)) = worktree_info else {
1700 return WorktreeSnapshot {
1701 worktree_path: String::new(),
1702 git_state: None,
1703 };
1704 };
1705
1706 let git_state = git_store
1707 .update(cx, |git_store, cx| {
1708 git_store
1709 .repositories()
1710 .values()
1711 .find(|repo| {
1712 repo.read(cx)
1713 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1714 .is_some()
1715 })
1716 .cloned()
1717 })
1718 .ok()
1719 .flatten()
1720 .map(|repo| {
1721 repo.update(cx, |repo, _| {
1722 let current_branch =
1723 repo.branch.as_ref().map(|branch| branch.name.to_string());
1724 repo.send_job(None, |state, _| async move {
1725 let RepositoryState::Local { backend, .. } = state else {
1726 return GitState {
1727 remote_url: None,
1728 head_sha: None,
1729 current_branch,
1730 diff: None,
1731 };
1732 };
1733
1734 let remote_url = backend.remote_url("origin");
1735 let head_sha = backend.head_sha();
1736 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1737
1738 GitState {
1739 remote_url,
1740 head_sha,
1741 current_branch,
1742 diff,
1743 }
1744 })
1745 })
1746 });
1747
1748 let git_state = match git_state {
1749 Some(git_state) => match git_state.ok() {
1750 Some(git_state) => git_state.await.ok(),
1751 None => None,
1752 },
1753 None => None,
1754 };
1755
1756 WorktreeSnapshot {
1757 worktree_path,
1758 git_state,
1759 }
1760 })
1761 }
1762
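    /// Renders the thread as Markdown, including per-message tool uses and tool
    /// results.
    ///
    /// A sketch (what the caller does with the string is up to them):
    ///
    /// ```ignore
    /// let markdown = thread.read(cx).to_markdown(cx)?;
    /// println!("{markdown}");
    /// ```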
1763 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1764 let mut markdown = Vec::new();
1765
1766 if let Some(summary) = self.summary() {
1767 writeln!(markdown, "# {summary}\n")?;
1768 };
1769
1770 for message in self.messages() {
1771 writeln!(
1772 markdown,
1773 "## {role}\n",
1774 role = match message.role {
1775 Role::User => "User",
1776 Role::Assistant => "Assistant",
1777 Role::System => "System",
1778 }
1779 )?;
1780
1781 if !message.context.is_empty() {
1782 writeln!(markdown, "{}", message.context)?;
1783 }
1784
1785 for segment in &message.segments {
1786 match segment {
1787 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1788 MessageSegment::Thinking(text) => {
1789 writeln!(markdown, "<think>{}</think>\n", text)?
1790 }
1791 }
1792 }
1793
1794 for tool_use in self.tool_uses_for_message(message.id, cx) {
1795 writeln!(
1796 markdown,
1797 "**Use Tool: {} ({})**",
1798 tool_use.name, tool_use.id
1799 )?;
1800 writeln!(markdown, "```json")?;
1801 writeln!(
1802 markdown,
1803 "{}",
1804 serde_json::to_string_pretty(&tool_use.input)?
1805 )?;
1806 writeln!(markdown, "```")?;
1807 }
1808
1809 for tool_result in self.tool_results_for_message(message.id) {
1810 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1811 if tool_result.is_error {
1812 write!(markdown, " (Error)")?;
1813 }
1814
1815 writeln!(markdown, "**\n")?;
1816 writeln!(markdown, "{}", tool_result.content)?;
1817 }
1818 }
1819
1820 Ok(String::from_utf8_lossy(&markdown).to_string())
1821 }
1822
1823 pub fn keep_edits_in_range(
1824 &mut self,
1825 buffer: Entity<language::Buffer>,
1826 buffer_range: Range<language::Anchor>,
1827 cx: &mut Context<Self>,
1828 ) {
1829 self.action_log.update(cx, |action_log, cx| {
1830 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1831 });
1832 }
1833
1834 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1835 self.action_log
1836 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1837 }
1838
1839 pub fn reject_edits_in_range(
1840 &mut self,
1841 buffer: Entity<language::Buffer>,
1842 buffer_range: Range<language::Anchor>,
1843 cx: &mut Context<Self>,
1844 ) -> Task<Result<()>> {
1845 self.action_log.update(cx, |action_log, cx| {
1846 action_log.reject_edits_in_range(buffer, buffer_range, cx)
1847 })
1848 }
1849
1850 pub fn action_log(&self) -> &Entity<ActionLog> {
1851 &self.action_log
1852 }
1853
1854 pub fn project(&self) -> &Entity<Project> {
1855 &self.project
1856 }
1857
1858 pub fn cumulative_token_usage(&self) -> TokenUsage {
1859 self.cumulative_token_usage.clone()
1860 }
1861
1862 pub fn auto_capture_telemetry(&self, cx: &mut Context<Self>) {
        // Rate-limit auto-capture to at most once every 10 seconds. Using a
        // `Mutex` in a `static` avoids the `static mut` + `unsafe` pattern.
        static LAST_CAPTURE: std::sync::Mutex<Option<std::time::Instant>> =
            std::sync::Mutex::new(None);

        let now = std::time::Instant::now();
        {
            let mut last_capture = LAST_CAPTURE.lock().unwrap();
            if let Some(last) = *last_capture {
                if now.duration_since(last).as_secs() < 10 {
                    return;
                }
            }
            *last_capture = Some(now);
        }
1878
        // Only auto-capture threads for users in the tracked cohort.
        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
            return;
        }
1886
1887 let thread_id = self.id().clone();
1888
1889 let github_handle = self
1890 .project
1891 .read(cx)
1892 .user_store()
1893 .read(cx)
1894 .current_user()
1895 .map(|user| user.github_login.clone());
1896
1897 let client = self.project.read(cx).client().clone();
1898
1899 let serialized_thread = self.serialize(cx);
1900
1901 cx.foreground_executor()
1902 .spawn(async move {
1903 if let Ok(serialized_thread) = serialized_thread.await {
1904 let thread_data = serde_json::to_value(serialized_thread)
1905 .unwrap_or_else(|_| serde_json::Value::Null);
1906
1907 telemetry::event!(
1908 "Agent Thread AutoCaptured",
1909 thread_id = thread_id.to_string(),
1910 thread_data = thread_data,
1911 auto_capture_reason = "tracked_user",
1912 github_handle = github_handle
1913 );
1914
1915 client.telemetry().flush_events();
1916 }
1917 })
1918 .detach();
1919 }
1920
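    /// Reports cumulative token usage against the configured model's context window.
    ///
    /// A sketch of how a caller might react to the ratio:
    ///
    /// ```ignore
    /// let usage = thread.read(cx).total_token_usage(cx);
    /// match usage.ratio {
    ///     TokenUsageRatio::Exceeded => { /* block sending and suggest a new thread */ }
    ///     TokenUsageRatio::Warning => { /* surface a warning in the UI */ }
    ///     TokenUsageRatio::Normal => {}
    /// }
    /// ```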
1921 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1922 let model_registry = LanguageModelRegistry::read_global(cx);
1923 let Some(model) = model_registry.default_model() else {
1924 return TotalTokenUsage::default();
1925 };
1926
1927 let max = model.model.max_token_count();
1928
1929 #[cfg(debug_assertions)]
1930 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1931 .unwrap_or("0.8".to_string())
1932 .parse()
1933 .unwrap();
1934 #[cfg(not(debug_assertions))]
1935 let warning_threshold: f32 = 0.8;
1936
1937 let total = self.cumulative_token_usage.total_tokens() as usize;
1938
1939 let ratio = if total >= max {
1940 TokenUsageRatio::Exceeded
1941 } else if total as f32 / max as f32 >= warning_threshold {
1942 TokenUsageRatio::Warning
1943 } else {
1944 TokenUsageRatio::Normal
1945 };
1946
1947 TotalTokenUsage { total, max, ratio }
1948 }
1949
1950 pub fn deny_tool_use(
1951 &mut self,
1952 tool_use_id: LanguageModelToolUseId,
1953 tool_name: Arc<str>,
1954 cx: &mut Context<Self>,
1955 ) {
1956 let err = Err(anyhow::anyhow!(
1957 "Permission to run tool action denied by user"
1958 ));
1959
1960 self.tool_use
1961 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1962 self.tool_finished(tool_use_id.clone(), None, true, cx);
1963 }
1964}
1965
1966#[derive(Debug, Clone)]
1967pub enum ThreadError {
1968 PaymentRequired,
1969 MaxMonthlySpendReached,
1970 Message {
1971 header: SharedString,
1972 message: SharedString,
1973 },
1974}
1975
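/// Events emitted by a [`Thread`] as it streams completions, runs tools, and updates
/// its summary.
///
/// A hedged sketch of observing these events from another entity (the subscriber
/// type and its context are illustrative):
///
/// ```ignore
/// cx.subscribe(&thread, |_this, _thread, event: &ThreadEvent, _cx| match event {
///     ThreadEvent::DoneStreaming => { /* re-enable the send button */ }
///     ThreadEvent::ShowError(error) => { /* surface `error` to the user */ }
///     _ => {}
/// })
/// .detach();
/// ```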
1976#[derive(Debug, Clone)]
1977pub enum ThreadEvent {
1978 ShowError(ThreadError),
1979 StreamedCompletion,
1980 StreamedAssistantText(MessageId, String),
1981 StreamedAssistantThinking(MessageId, String),
1982 DoneStreaming,
1983 MessageAdded(MessageId),
1984 MessageEdited(MessageId),
1985 MessageDeleted(MessageId),
1986 SummaryGenerated,
1987 SummaryChanged,
1988 UsePendingTools {
1989 tool_uses: Vec<PendingToolUse>,
1990 },
1991 ToolFinished {
1992 #[allow(unused)]
1993 tool_use_id: LanguageModelToolUseId,
1994 /// The pending tool use that corresponds to this tool.
1995 pending_tool_use: Option<PendingToolUse>,
1996 },
1997 CheckpointChanged,
1998 ToolConfirmationNeeded,
1999}
2000
2001impl EventEmitter<ThreadEvent> for Thread {}
2002
struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 1);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[0].string_contents(), expected_full_message);
    }

    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 3);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[0].string_contents().contains("file1.rs"));
        assert!(!request.messages[0].string_contents().contains("file2.rs"));
        assert!(!request.messages[0].string_contents().contains("file3.rs"));

        assert!(!request.messages[1].string_contents().contains("file1.rs"));
        assert!(request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(!request.messages[2].string_contents().contains("file2.rs"));
        assert!(request.messages[2].string_contents().contains("file3.rs"));
    }

    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 1);
        assert_eq!(
            request.messages[0].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[0].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[1].string_contents(),
            "Are there any good books?"
        );
    }

    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer so it no longer matches the version attached as context
        buffer.update(cx, |buffer, cx| {
            // The exact position doesn't matter; any edit marks the buffer as changed
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

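    // Registers the global settings and stores required by the components
    // exercised in these tests.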
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

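    // Builds a workspace, thread store, thread, and context store on top of
    // the given test project.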
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx.update(|_, cx| {
            ThreadStore::new(
                project.clone(),
                Arc::default(),
                Arc::new(PromptBuilder::new(None).unwrap()),
                cx,
            )
            .unwrap()
        });

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

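    // Opens `path` as a buffer in the project and adds it to the context
    // store, returning the buffer so tests can modify it afterwards.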
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}