1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5
6use agent_rules::load_worktree_rules_file;
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use fs::Fs;
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
19 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
20 LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
21 PaymentRequiredError, Role, StopReason, TokenUsage,
22};
23use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
24use project::{Project, Worktree};
25use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use settings::Settings;
29use util::{ResultExt as _, TryFutureExt as _, post_inc};
30use uuid::Uuid;
31
32use crate::context::{AssistantContext, ContextId, format_context_as_string};
33use crate::thread_store::{
34 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
35 SerializedToolUse,
36};
37use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
38
/// The purpose of a completion request built from a [`Thread`].
///
/// [`Thread::to_completion_request`] uses it to decide whether tool uses and tool results are
/// attached to the outgoing request.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A regular conversational turn: tool results and tool uses are attached to each message.
    Chat,
    /// Used when summarizing a thread.
    ///
    /// Messages that carry tool results are skipped entirely and tool uses are not attached.
    Summarize,
}
45
/// Identifier of a [`Thread`].
///
/// Backed by a shared `Arc<str>`, so cloning only bumps a reference count. Fresh ids are random
/// v4 UUID strings (`ThreadId::new`), but any string can be turned into an id via `From<&str>`.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
50
51impl ThreadId {
52 pub fn new() -> Self {
53 Self(Uuid::new_v4().to_string().into())
54 }
55}
56
57impl std::fmt::Display for ThreadId {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 write!(f, "{}", self.0)
60 }
61}
62
63impl From<&str> for ThreadId {
64 fn from(value: &str) -> Self {
65 Self(value.into())
66 }
67}
68
/// Identifier of a [`Message`] within a single [`Thread`].
///
/// The owning thread hands ids out sequentially from its `next_message_id` counter.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
71
72impl MessageId {
73 fn post_inc(&mut self) -> Self {
74 Self(post_inc(&mut self.0))
75 }
76}
77
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier of this message within its thread.
    pub id: MessageId,
    /// Author of the message (user, assistant or system).
    pub role: Role,
    /// Ordered content of the message: plain text and thinking segments.
    pub segments: Vec<MessageSegment>,
    /// Pre-formatted context attached to this message; empty when none was attached.
    /// `Message::to_string` places it before the segments.
    pub context: String,
}
86
87impl Message {
88 /// Returns whether the message contains any meaningful text that should be displayed
89 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
90 pub fn should_display_content(&self) -> bool {
91 self.segments.iter().all(|segment| segment.should_display())
92 }
93
94 pub fn push_thinking(&mut self, text: &str) {
95 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
96 segment.push_str(text);
97 } else {
98 self.segments
99 .push(MessageSegment::Thinking(text.to_string()));
100 }
101 }
102
103 pub fn push_text(&mut self, text: &str) {
104 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
105 segment.push_str(text);
106 } else {
107 self.segments.push(MessageSegment::Text(text.to_string()));
108 }
109 }
110
111 pub fn to_string(&self) -> String {
112 let mut result = String::new();
113
114 if !self.context.is_empty() {
115 result.push_str(&self.context);
116 }
117
118 for segment in &self.segments {
119 match segment {
120 MessageSegment::Text(text) => result.push_str(text),
121 MessageSegment::Thinking(text) => {
122 result.push_str("<think>");
123 result.push_str(text);
124 result.push_str("</think>");
125 }
126 }
127 }
128
129 result
130 }
131}
132
/// One piece of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Regular text.
    Text(String),
    /// Model "thinking" output. When a message is rendered as plain text, this is wrapped in
    /// `<think>` tags.
    Thinking(String),
}
138
139impl MessageSegment {
140 pub fn text_mut(&mut self) -> &mut String {
141 match self {
142 Self::Text(text) => text,
143 Self::Thinking(text) => text,
144 }
145 }
146
147 pub fn should_display(&self) -> bool {
148 // We add USING_TOOL_MARKER when making a request that includes tool uses
149 // without non-whitespace text around them, and this can cause the model
150 // to mimic the pattern, so we consider those segments not displayable.
151 match self {
152 Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
153 Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
154 }
155 }
156}
157
/// A serializable snapshot of a project's state at a point in time.
///
/// A thread stores the snapshot taken when it was created (see `Thread::new`) and includes it
/// when serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per worktree captured in the snapshot.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers presumably holding unsaved edits at snapshot time — TODO confirm
    /// against `Thread::project_snapshot`.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
164
/// Snapshot of a single worktree within a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree, stored as a string.
    pub worktree_path: String,
    /// Git information for the worktree, when it could be collected.
    pub git_state: Option<GitState>,
}
170
/// Git information captured for a worktree. Every field is optional because any of them may be
/// unavailable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the repository's remote, if any.
    pub remote_url: Option<String>,
    /// SHA of the `HEAD` commit.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch.
    pub current_branch: Option<String>,
    /// Textual diff of the worktree — presumably against `HEAD`; TODO confirm in
    /// `Thread::project_snapshot`.
    pub diff: Option<String>,
}
178
/// A git checkpoint recorded for a message of a [`Thread`].
///
/// Restoring it (`Thread::restore_checkpoint`) resets the project's git state. On success it
/// also truncates the thread from `message_id` onward.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    /// The message this checkpoint was recorded with.
    message_id: MessageId,
    /// The underlying git store checkpoint.
    git_checkpoint: GitStoreCheckpoint,
}
184
/// User feedback on a whole thread or on one of its messages.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    /// The user rated the response positively.
    Positive,
    /// The user rated the response negatively.
    Negative,
}
190
/// State of the most recent checkpoint restore.
///
/// The thread clears this back to `None` once a restore succeeds.
pub enum LastRestoreCheckpoint {
    /// A restore to the checkpoint of `message_id` is in progress.
    Pending {
        message_id: MessageId,
    },
    /// Restoring the checkpoint of `message_id` failed; `error` holds the error's text.
    Error {
        message_id: MessageId,
        error: String,
    },
}
200
201impl LastRestoreCheckpoint {
202 pub fn message_id(&self) -> MessageId {
203 match self {
204 LastRestoreCheckpoint::Pending { message_id } => *message_id,
205 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
206 }
207 }
208}
209
/// Progress of generating a detailed summary for a thread.
///
/// NOTE(review): the `message_id` carried by the non-default variants presumably identifies the
/// latest message the summary covers; confirm against the summary generator.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been requested yet.
    #[default]
    NotGenerated,
    /// A detailed summary is currently being generated.
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary is available in `text`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
222
/// Aggregate token usage of a thread compared against a limit.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens used so far.
    pub total: usize,
    /// Token limit — presumably the model's context window; TODO confirm at the producer.
    pub max: usize,
    /// Coarse classification of `total` relative to `max`.
    pub ratio: TokenUsageRatio,
}
229
/// Coarse bucket describing how close token usage is to its limit.
///
/// The thresholds between buckets are decided by whoever builds a [`TotalTokenUsage`].
#[derive(Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is not close to the limit (the default).
    #[default]
    Normal,
    /// Usage is approaching the limit.
    Warning,
    /// Usage has presumably reached or passed the limit.
    Exceeded,
}
237
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Unique identifier of the thread.
    id: ThreadId,
    /// Time of the last modification; refreshed by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// Short title; `None` until one has been generated (deserialized threads always have one).
    summary: Option<SharedString>,
    /// In-flight summary generation task, if any.
    pending_summary: Task<Option<()>>,
    /// Progress of the detailed (long-form) summary.
    detailed_summary_state: DetailedSummaryState,
    /// All messages, oldest first.
    messages: Vec<Message>,
    /// Id that the next inserted message will receive.
    next_message_id: MessageId,
    /// Every context attached by some message, keyed by id. `insert_user_message` consults it
    /// to avoid attaching the same context twice.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Context ids newly attached by each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Worktree information used to render the system prompt. An error is logged when a request
    /// is built without it.
    system_prompt_context: Option<AssistantSystemPromptContext>,
    /// Restorable git checkpoints, keyed by the message they were recorded with.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    /// Counter used to assign ids to pending completions.
    completion_count: usize,
    /// Completions that are currently streaming.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    /// Tools available to the model.
    tools: Arc<ToolWorkingSet>,
    /// Pending and finished tool uses together with their results.
    tool_use: ToolUseState,
    /// Tracks buffers the model has seen or edited.
    action_log: Entity<ActionLog>,
    /// State of the most recent checkpoint restore, if one is pending or failed.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured with the latest user message. It is kept only if the project changed
    /// by the time generation finishes (see `finalize_pending_checkpoint`).
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot taken at creation, or restored from serialization.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Running total of token usage across completions.
    cumulative_token_usage: TokenUsage,
    /// Thread-level feedback, if the user gave any.
    feedback: Option<ThreadFeedback>,
    /// Per-message feedback.
    message_feedback: HashMap<MessageId, ThreadFeedback>,
}
265
266impl Thread {
    /// Creates a new, empty thread for `project`.
    ///
    /// The thread gets a fresh random [`ThreadId`] and its own [`ActionLog`]. A project
    /// snapshot is started immediately and stored as a shared task, which `serialize` awaits
    /// later.
    pub fn new(
        project: Entity<Project>,
        tools: Arc<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            system_prompt_context: None,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Struct-literal fields are evaluated in source order, so `project` is only moved
            // here, after the clones above have been taken.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            cumulative_token_usage: TokenUsage::default(),
            feedback: None,
            message_feedback: HashMap::default(),
        }
    }
305
306 pub fn deserialize(
307 id: ThreadId,
308 serialized: SerializedThread,
309 project: Entity<Project>,
310 tools: Arc<ToolWorkingSet>,
311 prompt_builder: Arc<PromptBuilder>,
312 cx: &mut Context<Self>,
313 ) -> Self {
314 let next_message_id = MessageId(
315 serialized
316 .messages
317 .last()
318 .map(|message| message.id.0 + 1)
319 .unwrap_or(0),
320 );
321 let tool_use =
322 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
323
324 Self {
325 id,
326 updated_at: serialized.updated_at,
327 summary: Some(serialized.summary),
328 pending_summary: Task::ready(None),
329 detailed_summary_state: serialized.detailed_summary_state,
330 messages: serialized
331 .messages
332 .into_iter()
333 .map(|message| Message {
334 id: message.id,
335 role: message.role,
336 segments: message
337 .segments
338 .into_iter()
339 .map(|segment| match segment {
340 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
341 SerializedMessageSegment::Thinking { text } => {
342 MessageSegment::Thinking(text)
343 }
344 })
345 .collect(),
346 context: message.context,
347 })
348 .collect(),
349 next_message_id,
350 context: BTreeMap::default(),
351 context_by_message: HashMap::default(),
352 system_prompt_context: None,
353 checkpoints_by_message: HashMap::default(),
354 completion_count: 0,
355 pending_completions: Vec::new(),
356 last_restore_checkpoint: None,
357 pending_checkpoint: None,
358 project: project.clone(),
359 prompt_builder,
360 tools,
361 tool_use,
362 action_log: cx.new(|_| ActionLog::new(project)),
363 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
364 cumulative_token_usage: serialized.cumulative_token_usage,
365 feedback: None,
366 message_feedback: HashMap::default(),
367 }
368 }
369
370 pub fn id(&self) -> &ThreadId {
371 &self.id
372 }
373
374 pub fn is_empty(&self) -> bool {
375 self.messages.is_empty()
376 }
377
378 pub fn updated_at(&self) -> DateTime<Utc> {
379 self.updated_at
380 }
381
382 pub fn touch_updated_at(&mut self) {
383 self.updated_at = Utc::now();
384 }
385
386 pub fn summary(&self) -> Option<SharedString> {
387 self.summary.clone()
388 }
389
390 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
391
392 pub fn summary_or_default(&self) -> SharedString {
393 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
394 }
395
396 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
397 let Some(current_summary) = &self.summary else {
398 // Don't allow setting summary until generated
399 return;
400 };
401
402 let mut new_summary = new_summary.into();
403
404 if new_summary.is_empty() {
405 new_summary = Self::DEFAULT_SUMMARY;
406 }
407
408 if current_summary != &new_summary {
409 self.summary = Some(new_summary);
410 cx.emit(ThreadEvent::SummaryChanged);
411 }
412 }
413
414 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
415 self.latest_detailed_summary()
416 .unwrap_or_else(|| self.text().into())
417 }
418
419 fn latest_detailed_summary(&self) -> Option<SharedString> {
420 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
421 Some(text.clone())
422 } else {
423 None
424 }
425 }
426
427 pub fn message(&self, id: MessageId) -> Option<&Message> {
428 self.messages.iter().find(|message| message.id == id)
429 }
430
431 pub fn messages(&self) -> impl Iterator<Item = &Message> {
432 self.messages.iter()
433 }
434
435 pub fn is_generating(&self) -> bool {
436 !self.pending_completions.is_empty() || !self.all_tools_finished()
437 }
438
439 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
440 &self.tools
441 }
442
443 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
444 self.tool_use
445 .pending_tool_uses()
446 .into_iter()
447 .find(|tool_use| &tool_use.id == id)
448 }
449
450 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
451 self.tool_use
452 .pending_tool_uses()
453 .into_iter()
454 .filter(|tool_use| tool_use.status.needs_confirmation())
455 }
456
457 pub fn has_pending_tool_uses(&self) -> bool {
458 !self.tool_use.pending_tool_uses().is_empty()
459 }
460
461 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
462 self.checkpoints_by_message.get(&id).cloned()
463 }
464
    /// Restores the project's git state to `checkpoint`. On success, it also truncates the
    /// thread from the checkpoint's message onward (that message included).
    ///
    /// While the restore runs, `last_restore_checkpoint` is `Pending`. On failure it becomes
    /// `Error` with the error text, and the messages are left untouched. In every case the
    /// pending checkpoint is discarded and `ThreadEvent::CheckpointChanged` is emitted.
    ///
    /// The returned task resolves to the restore result. It returns an error early if the
    /// thread entity was released before the restore finished.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Publish the pending state first so the UI can reflect the in-flight restore.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Only drop messages once the project really is back at the checkpoint.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
500
    /// Decides whether the checkpoint captured with the last user message is worth keeping.
    ///
    /// This does nothing while the thread is still generating or when there is no pending
    /// checkpoint. Otherwise it takes the pending checkpoint and captures a fresh "final"
    /// checkpoint, then compares the two:
    /// - if they compare equal, the pending checkpoint is deleted;
    /// - otherwise (including when the comparison itself fails), it is recorded via
    ///   `insert_checkpoint`.
    ///
    /// The final checkpoint is always deleted afterwards. If capturing the final checkpoint
    /// fails, the pending checkpoint is recorded.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // Treat a failed comparison as "changed" so the checkpoint is kept.
                    .unwrap_or(false);

                if equal {
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The final checkpoint was only needed for the comparison.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
551
552 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
553 self.checkpoints_by_message
554 .insert(checkpoint.message_id, checkpoint);
555 cx.emit(ThreadEvent::CheckpointChanged);
556 cx.notify();
557 }
558
559 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
560 self.last_restore_checkpoint.as_ref()
561 }
562
563 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
564 let Some(message_ix) = self
565 .messages
566 .iter()
567 .rposition(|message| message.id == message_id)
568 else {
569 return;
570 };
571 for deleted_message in self.messages.drain(message_ix..) {
572 self.context_by_message.remove(&deleted_message.id);
573 self.checkpoints_by_message.remove(&deleted_message.id);
574 }
575 cx.notify();
576 }
577
578 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
579 self.context_by_message
580 .get(&id)
581 .into_iter()
582 .flat_map(|context| {
583 context
584 .iter()
585 .filter_map(|context_id| self.context.get(&context_id))
586 })
587 }
588
589 /// Returns whether all of the tool uses have finished running.
590 pub fn all_tools_finished(&self) -> bool {
591 // If the only pending tool uses left are the ones with errors, then
592 // that means that we've finished running all of the pending tools.
593 self.tool_use
594 .pending_tool_uses()
595 .iter()
596 .all(|tool_use| tool_use.status.is_error())
597 }
598
599 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
600 self.tool_use.tool_uses_for_message(id, cx)
601 }
602
603 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
604 self.tool_use.tool_results_for_message(id)
605 }
606
607 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
608 self.tool_use.tool_result(id)
609 }
610
611 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
612 self.tool_use.message_has_tool_results(message_id)
613 }
614
615 pub fn insert_user_message(
616 &mut self,
617 text: impl Into<String>,
618 context: Vec<AssistantContext>,
619 git_checkpoint: Option<GitStoreCheckpoint>,
620 cx: &mut Context<Self>,
621 ) -> MessageId {
622 let text = text.into();
623
624 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
625
626 // Filter out contexts that have already been included in previous messages
627 let new_context: Vec<_> = context
628 .into_iter()
629 .filter(|ctx| !self.context.contains_key(&ctx.id()))
630 .collect();
631
632 if !new_context.is_empty() {
633 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
634 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
635 message.context = context_string;
636 }
637 }
638
639 self.action_log.update(cx, |log, cx| {
640 // Track all buffers added as context
641 for ctx in &new_context {
642 match ctx {
643 AssistantContext::File(file_ctx) => {
644 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
645 }
646 AssistantContext::Directory(dir_ctx) => {
647 for context_buffer in &dir_ctx.context_buffers {
648 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
649 }
650 }
651 AssistantContext::Symbol(symbol_ctx) => {
652 log.buffer_added_as_context(
653 symbol_ctx.context_symbol.buffer.clone(),
654 cx,
655 );
656 }
657 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
658 }
659 }
660 });
661 }
662
663 let context_ids = new_context
664 .iter()
665 .map(|context| context.id())
666 .collect::<Vec<_>>();
667 self.context.extend(
668 new_context
669 .into_iter()
670 .map(|context| (context.id(), context)),
671 );
672 self.context_by_message.insert(message_id, context_ids);
673
674 if let Some(git_checkpoint) = git_checkpoint {
675 self.pending_checkpoint = Some(ThreadCheckpoint {
676 message_id,
677 git_checkpoint,
678 });
679 }
680 message_id
681 }
682
683 pub fn insert_message(
684 &mut self,
685 role: Role,
686 segments: Vec<MessageSegment>,
687 cx: &mut Context<Self>,
688 ) -> MessageId {
689 let id = self.next_message_id.post_inc();
690 self.messages.push(Message {
691 id,
692 role,
693 segments,
694 context: String::new(),
695 });
696 self.touch_updated_at();
697 cx.emit(ThreadEvent::MessageAdded(id));
698 id
699 }
700
701 pub fn edit_message(
702 &mut self,
703 id: MessageId,
704 new_role: Role,
705 new_segments: Vec<MessageSegment>,
706 cx: &mut Context<Self>,
707 ) -> bool {
708 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
709 return false;
710 };
711 message.role = new_role;
712 message.segments = new_segments;
713 self.touch_updated_at();
714 cx.emit(ThreadEvent::MessageEdited(id));
715 true
716 }
717
718 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
719 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
720 return false;
721 };
722 self.messages.remove(index);
723 self.context_by_message.remove(&id);
724 self.touch_updated_at();
725 cx.emit(ThreadEvent::MessageDeleted(id));
726 true
727 }
728
729 /// Returns the representation of this [`Thread`] in a textual form.
730 ///
731 /// This is the representation we use when attaching a thread as context to another thread.
732 pub fn text(&self) -> String {
733 let mut text = String::new();
734
735 for message in &self.messages {
736 text.push_str(match message.role {
737 language_model::Role::User => "User:",
738 language_model::Role::Assistant => "Assistant:",
739 language_model::Role::System => "System:",
740 });
741 text.push('\n');
742
743 for segment in &message.segments {
744 match segment {
745 MessageSegment::Text(content) => text.push_str(content),
746 MessageSegment::Thinking(content) => {
747 text.push_str(&format!("<think>{}</think>", content))
748 }
749 }
750 }
751 text.push('\n');
752 }
753
754 text
755 }
756
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// The returned task first waits for the initial project snapshot and then reads the
    /// thread's current state. The summary falls back to `DEFAULT_SUMMARY`. Each message is
    /// stored with its segments, the tool uses and tool results recorded for it, and its
    /// attached context string. The task fails if the thread entity was released in the
    /// meantime.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking(text) => {
                                    SerializedMessageSegment::Thinking { text: text.clone() }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
            })
        })
    }
810
811 pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
812 self.system_prompt_context = Some(context);
813 }
814
815 pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
816 &self.system_prompt_context
817 }
818
819 pub fn load_system_prompt_context(
820 &self,
821 cx: &App,
822 ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
823 let project = self.project.read(cx);
824 let tasks = project
825 .visible_worktrees(cx)
826 .map(|worktree| {
827 Self::load_worktree_info_for_system_prompt(
828 project.fs().clone(),
829 worktree.read(cx),
830 cx,
831 )
832 })
833 .collect::<Vec<_>>();
834
835 cx.spawn(async |_cx| {
836 let results = futures::future::join_all(tasks).await;
837 let mut first_err = None;
838 let worktrees = results
839 .into_iter()
840 .map(|(worktree, err)| {
841 if first_err.is_none() && err.is_some() {
842 first_err = err;
843 }
844 worktree
845 })
846 .collect::<Vec<_>>();
847 (AssistantSystemPromptContext::new(worktrees), first_err)
848 })
849 }
850
851 fn load_worktree_info_for_system_prompt(
852 fs: Arc<dyn Fs>,
853 worktree: &Worktree,
854 cx: &App,
855 ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
856 let root_name = worktree.root_name().into();
857 let abs_path = worktree.abs_path();
858
859 let rules_task = load_worktree_rules_file(fs, worktree, cx);
860 let Some(rules_task) = rules_task else {
861 return Task::ready((
862 WorktreeInfoForSystemPrompt {
863 root_name,
864 abs_path,
865 rules_file: None,
866 },
867 None,
868 ));
869 };
870
871 cx.spawn(async move |_| {
872 let (rules_file, rules_file_error) = match rules_task.await {
873 Ok(rules_file) => (Some(rules_file), None),
874 Err(err) => (
875 None,
876 Some(ThreadError::Message {
877 header: "Error loading rules file".into(),
878 message: format!("{err}").into(),
879 }),
880 ),
881 };
882 let worktree_info = WorktreeInfoForSystemPrompt {
883 root_name,
884 abs_path,
885 rules_file,
886 };
887 (worktree_info, rules_file_error)
888 })
889 }
890
891 pub fn send_to_model(
892 &mut self,
893 model: Arc<dyn LanguageModel>,
894 request_kind: RequestKind,
895 cx: &mut Context<Self>,
896 ) {
897 let mut request = self.to_completion_request(request_kind, cx);
898 if model.supports_tools() {
899 request.tools = {
900 let mut tools = Vec::new();
901 tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
902 LanguageModelRequestTool {
903 name: tool.name(),
904 description: tool.description(),
905 input_schema: tool.input_schema(model.tool_input_format()),
906 }
907 }));
908
909 tools
910 };
911 }
912
913 self.stream_completion(request, model, cx);
914 }
915
916 pub fn used_tools_since_last_user_message(&self) -> bool {
917 for message in self.messages.iter().rev() {
918 if self.tool_use.message_has_tool_results(message.id) {
919 return true;
920 } else if message.role == Role::User {
921 return false;
922 }
923 }
924
925 false
926 }
927
    /// Builds a [`LanguageModelRequest`] from the thread's messages.
    ///
    /// The request contains, in order:
    /// 1. the system prompt, marked cacheable. It is added only when `system_prompt_context`
    ///    is set and the prompt renders; otherwise an error is logged.
    /// 2. one request message per thread message. For [`RequestKind::Chat`], tool results
    ///    come before the text and tool uses after it. For [`RequestKind::Summarize`], messages
    ///    carrying tool results are skipped and no tool uses are attached.
    /// 3. an optional user message about stale files and the diagnostics check (see
    ///    `attached_tracked_files_state`).
    ///
    /// Before step 3 is appended, the last message so far is marked cacheable. Finally, a
    /// system-prompt reminder is appended to the last user-role message. Tools are left empty
    /// here; `send_to_model` fills them in.
    pub fn to_completion_request(
        &self,
        request_kind: RequestKind,
        cx: &App,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
            if let Some(system_prompt) = self
                .prompt_builder
                .generate_assistant_system_prompt(system_prompt_context)
                .context("failed to generate assistant system prompt")
                .log_err()
            {
                request.messages.push(LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text(system_prompt)],
                    cache: true,
                });
            }
        } else {
            log::error!("system_prompt_context not set.")
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_results(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                    if self.tool_use.message_has_tool_results(message.id) {
                        continue;
                    }
                }
            }

            // `to_string` prepends the message's attached context to its segments.
            if !message.segments.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.to_string()));
            }

            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            };

            request.messages.push(request_message);
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        // Marked before the tracked-files message below is appended, presumably so that the
        // cache breakpoint sits on the stable conversation prefix.
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        // Add reminder to the last user message about code blocks
        if let Some(last_user_message) = request
            .messages
            .iter_mut()
            .rev()
            .find(|msg| msg.role == Role::User)
        {
            last_user_message
                .content
                .push(MessageContent::Text(system_prompt_reminder(
                    &self.prompt_builder,
                )));
        }

        request
    }
1019
1020 fn attached_tracked_files_state(
1021 &self,
1022 messages: &mut Vec<LanguageModelRequestMessage>,
1023 cx: &App,
1024 ) {
1025 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1026
1027 let mut stale_message = String::new();
1028
1029 let action_log = self.action_log.read(cx);
1030
1031 for stale_file in action_log.stale_buffers(cx) {
1032 let Some(file) = stale_file.read(cx).file() else {
1033 continue;
1034 };
1035
1036 if stale_message.is_empty() {
1037 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1038 }
1039
1040 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1041 }
1042
1043 let mut content = Vec::with_capacity(2);
1044
1045 if !stale_message.is_empty() {
1046 content.push(stale_message.into());
1047 }
1048
1049 if action_log.has_edited_files_since_project_diagnostics_check() {
1050 content.push(
1051 "\n\nWhen you're done making changes, make sure to check project diagnostics \
1052 and fix all errors AND warnings you introduced! \
1053 DO NOT mention you're going to do this until you're done."
1054 .into(),
1055 );
1056 }
1057
1058 if !content.is_empty() {
1059 let context_message = LanguageModelRequestMessage {
1060 role: Role::User,
1061 content,
1062 cache: false,
1063 };
1064
1065 messages.push(context_message);
1066 }
1067 }
1068
1069 pub fn stream_completion(
1070 &mut self,
1071 request: LanguageModelRequest,
1072 model: Arc<dyn LanguageModel>,
1073 cx: &mut Context<Self>,
1074 ) {
1075 let pending_completion_id = post_inc(&mut self.completion_count);
1076
1077 let task = cx.spawn(async move |thread, cx| {
1078 let stream = model.stream_completion(request, &cx);
1079 let initial_token_usage =
1080 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
1081 let stream_completion = async {
1082 let mut events = stream.await?;
1083 let mut stop_reason = StopReason::EndTurn;
1084 let mut current_token_usage = TokenUsage::default();
1085
1086 while let Some(event) = events.next().await {
1087 let event = event?;
1088
1089 thread.update(cx, |thread, cx| {
1090 match event {
1091 LanguageModelCompletionEvent::StartMessage { .. } => {
1092 thread.insert_message(
1093 Role::Assistant,
1094 vec![MessageSegment::Text(String::new())],
1095 cx,
1096 );
1097 }
1098 LanguageModelCompletionEvent::Stop(reason) => {
1099 stop_reason = reason;
1100 }
1101 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1102 thread.cumulative_token_usage =
1103 thread.cumulative_token_usage.clone() + token_usage.clone()
1104 - current_token_usage.clone();
1105 current_token_usage = token_usage;
1106 }
1107 LanguageModelCompletionEvent::Text(chunk) => {
1108 if let Some(last_message) = thread.messages.last_mut() {
1109 if last_message.role == Role::Assistant {
1110 last_message.push_text(&chunk);
1111 cx.emit(ThreadEvent::StreamedAssistantText(
1112 last_message.id,
1113 chunk,
1114 ));
1115 } else {
1116 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1117 // of a new Assistant response.
1118 //
1119 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1120 // will result in duplicating the text of the chunk in the rendered Markdown.
1121 thread.insert_message(
1122 Role::Assistant,
1123 vec![MessageSegment::Text(chunk.to_string())],
1124 cx,
1125 );
1126 };
1127 }
1128 }
1129 LanguageModelCompletionEvent::Thinking(chunk) => {
1130 if let Some(last_message) = thread.messages.last_mut() {
1131 if last_message.role == Role::Assistant {
1132 last_message.push_thinking(&chunk);
1133 cx.emit(ThreadEvent::StreamedAssistantThinking(
1134 last_message.id,
1135 chunk,
1136 ));
1137 } else {
1138 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1139 // of a new Assistant response.
1140 //
1141 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1142 // will result in duplicating the text of the chunk in the rendered Markdown.
1143 thread.insert_message(
1144 Role::Assistant,
1145 vec![MessageSegment::Thinking(chunk.to_string())],
1146 cx,
1147 );
1148 };
1149 }
1150 }
1151 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1152 let last_assistant_message_id = thread
1153 .messages
1154 .iter_mut()
1155 .rfind(|message| message.role == Role::Assistant)
1156 .map(|message| message.id)
1157 .unwrap_or_else(|| {
1158 thread.insert_message(Role::Assistant, vec![], cx)
1159 });
1160
1161 thread.tool_use.request_tool_use(
1162 last_assistant_message_id,
1163 tool_use,
1164 cx,
1165 );
1166 }
1167 }
1168
1169 thread.touch_updated_at();
1170 cx.emit(ThreadEvent::StreamedCompletion);
1171 cx.notify();
1172 })?;
1173
1174 smol::future::yield_now().await;
1175 }
1176
1177 thread.update(cx, |thread, cx| {
1178 thread
1179 .pending_completions
1180 .retain(|completion| completion.id != pending_completion_id);
1181
1182 if thread.summary.is_none() && thread.messages.len() >= 2 {
1183 thread.summarize(cx);
1184 }
1185 })?;
1186
1187 anyhow::Ok(stop_reason)
1188 };
1189
1190 let result = stream_completion.await;
1191
1192 thread
1193 .update(cx, |thread, cx| {
1194 thread.finalize_pending_checkpoint(cx);
1195 match result.as_ref() {
1196 Ok(stop_reason) => match stop_reason {
1197 StopReason::ToolUse => {
1198 cx.emit(ThreadEvent::UsePendingTools);
1199 }
1200 StopReason::EndTurn => {}
1201 StopReason::MaxTokens => {}
1202 },
1203 Err(error) => {
1204 if error.is::<PaymentRequiredError>() {
1205 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1206 } else if error.is::<MaxMonthlySpendReachedError>() {
1207 cx.emit(ThreadEvent::ShowError(
1208 ThreadError::MaxMonthlySpendReached,
1209 ));
1210 } else {
1211 let error_message = error
1212 .chain()
1213 .map(|err| err.to_string())
1214 .collect::<Vec<_>>()
1215 .join("\n");
1216 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1217 header: "Error interacting with language model".into(),
1218 message: SharedString::from(error_message.clone()),
1219 }));
1220 }
1221
1222 thread.cancel_last_completion(cx);
1223 }
1224 }
1225 cx.emit(ThreadEvent::DoneStreaming);
1226
1227 if let Ok(initial_usage) = initial_token_usage {
1228 let usage = thread.cumulative_token_usage.clone() - initial_usage;
1229
1230 telemetry::event!(
1231 "Assistant Thread Completion",
1232 thread_id = thread.id().to_string(),
1233 model = model.telemetry_id(),
1234 model_provider = model.provider_id().to_string(),
1235 input_tokens = usage.input_tokens,
1236 output_tokens = usage.output_tokens,
1237 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1238 cache_read_input_tokens = usage.cache_read_input_tokens,
1239 );
1240 }
1241 })
1242 .ok();
1243 });
1244
1245 self.pending_completions.push(PendingCompletion {
1246 id: pending_completion_id,
1247 _task: task,
1248 });
1249 }
1250
1251 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1252 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1253 return;
1254 };
1255
1256 if !model.provider.is_authenticated(cx) {
1257 return;
1258 }
1259
1260 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1261 request.messages.push(LanguageModelRequestMessage {
1262 role: Role::User,
1263 content: vec![
1264 "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1265 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1266 If the conversation is about a specific subject, include it in the title. \
1267 Be descriptive. DO NOT speak in the first person."
1268 .into(),
1269 ],
1270 cache: false,
1271 });
1272
1273 self.pending_summary = cx.spawn(async move |this, cx| {
1274 async move {
1275 let stream = model.model.stream_completion_text(request, &cx);
1276 let mut messages = stream.await?;
1277
1278 let mut new_summary = String::new();
1279 while let Some(message) = messages.stream.next().await {
1280 let text = message?;
1281 let mut lines = text.lines();
1282 new_summary.extend(lines.next());
1283
1284 // Stop if the LLM generated multiple lines.
1285 if lines.next().is_some() {
1286 break;
1287 }
1288 }
1289
1290 this.update(cx, |this, cx| {
1291 if !new_summary.is_empty() {
1292 this.summary = Some(new_summary.into());
1293 }
1294
1295 cx.emit(ThreadEvent::SummaryGenerated);
1296 })?;
1297
1298 anyhow::Ok(())
1299 }
1300 .log_err()
1301 .await
1302 });
1303 }
1304
1305 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1306 let last_message_id = self.messages.last().map(|message| message.id)?;
1307
1308 match &self.detailed_summary_state {
1309 DetailedSummaryState::Generating { message_id, .. }
1310 | DetailedSummaryState::Generated { message_id, .. }
1311 if *message_id == last_message_id =>
1312 {
1313 // Already up-to-date
1314 return None;
1315 }
1316 _ => {}
1317 }
1318
1319 let ConfiguredModel { model, provider } =
1320 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1321
1322 if !provider.is_authenticated(cx) {
1323 return None;
1324 }
1325
1326 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1327
1328 request.messages.push(LanguageModelRequestMessage {
1329 role: Role::User,
1330 content: vec![
1331 "Generate a detailed summary of this conversation. Include:\n\
1332 1. A brief overview of what was discussed\n\
1333 2. Key facts or information discovered\n\
1334 3. Outcomes or conclusions reached\n\
1335 4. Any action items or next steps if any\n\
1336 Format it in Markdown with headings and bullet points."
1337 .into(),
1338 ],
1339 cache: false,
1340 });
1341
1342 let task = cx.spawn(async move |thread, cx| {
1343 let stream = model.stream_completion_text(request, &cx);
1344 let Some(mut messages) = stream.await.log_err() else {
1345 thread
1346 .update(cx, |this, _cx| {
1347 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1348 })
1349 .log_err();
1350
1351 return;
1352 };
1353
1354 let mut new_detailed_summary = String::new();
1355
1356 while let Some(chunk) = messages.stream.next().await {
1357 if let Some(chunk) = chunk.log_err() {
1358 new_detailed_summary.push_str(&chunk);
1359 }
1360 }
1361
1362 thread
1363 .update(cx, |this, _cx| {
1364 this.detailed_summary_state = DetailedSummaryState::Generated {
1365 text: new_detailed_summary.into(),
1366 message_id: last_message_id,
1367 };
1368 })
1369 .log_err();
1370 });
1371
1372 self.detailed_summary_state = DetailedSummaryState::Generating {
1373 message_id: last_message_id,
1374 };
1375
1376 Some(task)
1377 }
1378
1379 pub fn is_generating_detailed_summary(&self) -> bool {
1380 matches!(
1381 self.detailed_summary_state,
1382 DetailedSummaryState::Generating { .. }
1383 )
1384 }
1385
1386 pub fn use_pending_tools(
1387 &mut self,
1388 cx: &mut Context<Self>,
1389 ) -> impl IntoIterator<Item = PendingToolUse> + use<> {
1390 let request = self.to_completion_request(RequestKind::Chat, cx);
1391 let messages = Arc::new(request.messages);
1392 let pending_tool_uses = self
1393 .tool_use
1394 .pending_tool_uses()
1395 .into_iter()
1396 .filter(|tool_use| tool_use.status.is_idle())
1397 .cloned()
1398 .collect::<Vec<_>>();
1399
1400 for tool_use in pending_tool_uses.iter() {
1401 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1402 if tool.needs_confirmation(&tool_use.input, cx)
1403 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1404 {
1405 self.tool_use.confirm_tool_use(
1406 tool_use.id.clone(),
1407 tool_use.ui_text.clone(),
1408 tool_use.input.clone(),
1409 messages.clone(),
1410 tool,
1411 );
1412 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1413 } else {
1414 self.run_tool(
1415 tool_use.id.clone(),
1416 tool_use.ui_text.clone(),
1417 tool_use.input.clone(),
1418 &messages,
1419 tool,
1420 cx,
1421 );
1422 }
1423 }
1424 }
1425
1426 pending_tool_uses
1427 }
1428
1429 pub fn run_tool(
1430 &mut self,
1431 tool_use_id: LanguageModelToolUseId,
1432 ui_text: impl Into<SharedString>,
1433 input: serde_json::Value,
1434 messages: &[LanguageModelRequestMessage],
1435 tool: Arc<dyn Tool>,
1436 cx: &mut Context<Thread>,
1437 ) {
1438 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1439 self.tool_use
1440 .run_pending_tool(tool_use_id, ui_text.into(), task);
1441 }
1442
1443 fn spawn_tool_use(
1444 &mut self,
1445 tool_use_id: LanguageModelToolUseId,
1446 messages: &[LanguageModelRequestMessage],
1447 input: serde_json::Value,
1448 tool: Arc<dyn Tool>,
1449 cx: &mut Context<Thread>,
1450 ) -> Task<()> {
1451 let tool_name: Arc<str> = tool.name().into();
1452
1453 let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
1454 Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
1455 } else {
1456 tool.run(
1457 input,
1458 messages,
1459 self.project.clone(),
1460 self.action_log.clone(),
1461 cx,
1462 )
1463 };
1464
1465 cx.spawn({
1466 async move |thread: WeakEntity<Thread>, cx| {
1467 let output = run_tool.await;
1468
1469 thread
1470 .update(cx, |thread, cx| {
1471 let pending_tool_use = thread.tool_use.insert_tool_output(
1472 tool_use_id.clone(),
1473 tool_name,
1474 output,
1475 cx,
1476 );
1477
1478 cx.emit(ThreadEvent::ToolFinished {
1479 tool_use_id,
1480 pending_tool_use,
1481 canceled: false,
1482 });
1483 })
1484 .ok();
1485 }
1486 })
1487 }
1488
1489 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1490 // Insert a user message to contain the tool results.
1491 self.insert_user_message(
1492 // TODO: Sending up a user message without any content results in the model sending back
1493 // responses that also don't have any content. We currently don't handle this case well,
1494 // so for now we provide some text to keep the model on track.
1495 "Here are the tool results.",
1496 Vec::new(),
1497 None,
1498 cx,
1499 );
1500 }
1501
1502 /// Cancels the last pending completion, if there are any pending.
1503 ///
1504 /// Returns whether a completion was canceled.
1505 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1506 let canceled = if self.pending_completions.pop().is_some() {
1507 true
1508 } else {
1509 let mut canceled = false;
1510 for pending_tool_use in self.tool_use.cancel_pending() {
1511 canceled = true;
1512 cx.emit(ThreadEvent::ToolFinished {
1513 tool_use_id: pending_tool_use.id.clone(),
1514 pending_tool_use: Some(pending_tool_use),
1515 canceled: true,
1516 });
1517 }
1518 canceled
1519 };
1520 self.finalize_pending_checkpoint(cx);
1521 canceled
1522 }
1523
1524 pub fn feedback(&self) -> Option<ThreadFeedback> {
1525 self.feedback
1526 }
1527
1528 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1529 self.message_feedback.get(&message_id).copied()
1530 }
1531
1532 pub fn report_message_feedback(
1533 &mut self,
1534 message_id: MessageId,
1535 feedback: ThreadFeedback,
1536 cx: &mut Context<Self>,
1537 ) -> Task<Result<()>> {
1538 if self.message_feedback.get(&message_id) == Some(&feedback) {
1539 return Task::ready(Ok(()));
1540 }
1541
1542 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1543 let serialized_thread = self.serialize(cx);
1544 let thread_id = self.id().clone();
1545 let client = self.project.read(cx).client();
1546
1547 self.message_feedback.insert(message_id, feedback);
1548
1549 cx.notify();
1550
1551 let message_content = self
1552 .message(message_id)
1553 .map(|msg| msg.to_string())
1554 .unwrap_or_default();
1555
1556 cx.background_spawn(async move {
1557 let final_project_snapshot = final_project_snapshot.await;
1558 let serialized_thread = serialized_thread.await?;
1559 let thread_data =
1560 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1561
1562 let rating = match feedback {
1563 ThreadFeedback::Positive => "positive",
1564 ThreadFeedback::Negative => "negative",
1565 };
1566 telemetry::event!(
1567 "Assistant Thread Rated",
1568 rating,
1569 thread_id,
1570 message_id = message_id.0,
1571 message_content,
1572 thread_data,
1573 final_project_snapshot
1574 );
1575 client.telemetry().flush_events();
1576
1577 Ok(())
1578 })
1579 }
1580
1581 pub fn report_feedback(
1582 &mut self,
1583 feedback: ThreadFeedback,
1584 cx: &mut Context<Self>,
1585 ) -> Task<Result<()>> {
1586 let last_assistant_message_id = self
1587 .messages
1588 .iter()
1589 .rev()
1590 .find(|msg| msg.role == Role::Assistant)
1591 .map(|msg| msg.id);
1592
1593 if let Some(message_id) = last_assistant_message_id {
1594 self.report_message_feedback(message_id, feedback, cx)
1595 } else {
1596 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1597 let serialized_thread = self.serialize(cx);
1598 let thread_id = self.id().clone();
1599 let client = self.project.read(cx).client();
1600 self.feedback = Some(feedback);
1601 cx.notify();
1602
1603 cx.background_spawn(async move {
1604 let final_project_snapshot = final_project_snapshot.await;
1605 let serialized_thread = serialized_thread.await?;
1606 let thread_data = serde_json::to_value(serialized_thread)
1607 .unwrap_or_else(|_| serde_json::Value::Null);
1608
1609 let rating = match feedback {
1610 ThreadFeedback::Positive => "positive",
1611 ThreadFeedback::Negative => "negative",
1612 };
1613 telemetry::event!(
1614 "Assistant Thread Rated",
1615 rating,
1616 thread_id,
1617 thread_data,
1618 final_project_snapshot
1619 );
1620 client.telemetry().flush_events();
1621
1622 Ok(())
1623 })
1624 }
1625 }
1626
1627 /// Create a snapshot of the current project state including git information and unsaved buffers.
1628 fn project_snapshot(
1629 project: Entity<Project>,
1630 cx: &mut Context<Self>,
1631 ) -> Task<Arc<ProjectSnapshot>> {
1632 let git_store = project.read(cx).git_store().clone();
1633 let worktree_snapshots: Vec<_> = project
1634 .read(cx)
1635 .visible_worktrees(cx)
1636 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1637 .collect();
1638
1639 cx.spawn(async move |_, cx| {
1640 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1641
1642 let mut unsaved_buffers = Vec::new();
1643 cx.update(|app_cx| {
1644 let buffer_store = project.read(app_cx).buffer_store();
1645 for buffer_handle in buffer_store.read(app_cx).buffers() {
1646 let buffer = buffer_handle.read(app_cx);
1647 if buffer.is_dirty() {
1648 if let Some(file) = buffer.file() {
1649 let path = file.path().to_string_lossy().to_string();
1650 unsaved_buffers.push(path);
1651 }
1652 }
1653 }
1654 })
1655 .ok();
1656
1657 Arc::new(ProjectSnapshot {
1658 worktree_snapshots,
1659 unsaved_buffer_paths: unsaved_buffers,
1660 timestamp: Utc::now(),
1661 })
1662 })
1663 }
1664
1665 fn worktree_snapshot(
1666 worktree: Entity<project::Worktree>,
1667 git_store: Entity<GitStore>,
1668 cx: &App,
1669 ) -> Task<WorktreeSnapshot> {
1670 cx.spawn(async move |cx| {
1671 // Get worktree path and snapshot
1672 let worktree_info = cx.update(|app_cx| {
1673 let worktree = worktree.read(app_cx);
1674 let path = worktree.abs_path().to_string_lossy().to_string();
1675 let snapshot = worktree.snapshot();
1676 (path, snapshot)
1677 });
1678
1679 let Ok((worktree_path, _snapshot)) = worktree_info else {
1680 return WorktreeSnapshot {
1681 worktree_path: String::new(),
1682 git_state: None,
1683 };
1684 };
1685
1686 let git_state = git_store
1687 .update(cx, |git_store, cx| {
1688 git_store
1689 .repositories()
1690 .values()
1691 .find(|repo| {
1692 repo.read(cx)
1693 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1694 .is_some()
1695 })
1696 .cloned()
1697 })
1698 .ok()
1699 .flatten()
1700 .map(|repo| {
1701 repo.update(cx, |repo, _| {
1702 let current_branch =
1703 repo.branch.as_ref().map(|branch| branch.name.to_string());
1704 repo.send_job(None, |state, _| async move {
1705 let RepositoryState::Local { backend, .. } = state else {
1706 return GitState {
1707 remote_url: None,
1708 head_sha: None,
1709 current_branch,
1710 diff: None,
1711 };
1712 };
1713
1714 let remote_url = backend.remote_url("origin");
1715 let head_sha = backend.head_sha();
1716 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1717
1718 GitState {
1719 remote_url,
1720 head_sha,
1721 current_branch,
1722 diff,
1723 }
1724 })
1725 })
1726 });
1727
1728 let git_state = match git_state {
1729 Some(git_state) => match git_state.ok() {
1730 Some(git_state) => git_state.await.ok(),
1731 None => None,
1732 },
1733 None => None,
1734 };
1735
1736 WorktreeSnapshot {
1737 worktree_path,
1738 git_state,
1739 }
1740 })
1741 }
1742
1743 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1744 let mut markdown = Vec::new();
1745
1746 if let Some(summary) = self.summary() {
1747 writeln!(markdown, "# {summary}\n")?;
1748 };
1749
1750 for message in self.messages() {
1751 writeln!(
1752 markdown,
1753 "## {role}\n",
1754 role = match message.role {
1755 Role::User => "User",
1756 Role::Assistant => "Assistant",
1757 Role::System => "System",
1758 }
1759 )?;
1760
1761 if !message.context.is_empty() {
1762 writeln!(markdown, "{}", message.context)?;
1763 }
1764
1765 for segment in &message.segments {
1766 match segment {
1767 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1768 MessageSegment::Thinking(text) => {
1769 writeln!(markdown, "<think>{}</think>\n", text)?
1770 }
1771 }
1772 }
1773
1774 for tool_use in self.tool_uses_for_message(message.id, cx) {
1775 writeln!(
1776 markdown,
1777 "**Use Tool: {} ({})**",
1778 tool_use.name, tool_use.id
1779 )?;
1780 writeln!(markdown, "```json")?;
1781 writeln!(
1782 markdown,
1783 "{}",
1784 serde_json::to_string_pretty(&tool_use.input)?
1785 )?;
1786 writeln!(markdown, "```")?;
1787 }
1788
1789 for tool_result in self.tool_results_for_message(message.id) {
1790 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1791 if tool_result.is_error {
1792 write!(markdown, " (Error)")?;
1793 }
1794
1795 writeln!(markdown, "**\n")?;
1796 writeln!(markdown, "{}", tool_result.content)?;
1797 }
1798 }
1799
1800 Ok(String::from_utf8_lossy(&markdown).to_string())
1801 }
1802
1803 pub fn keep_edits_in_range(
1804 &mut self,
1805 buffer: Entity<language::Buffer>,
1806 buffer_range: Range<language::Anchor>,
1807 cx: &mut Context<Self>,
1808 ) {
1809 self.action_log.update(cx, |action_log, cx| {
1810 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1811 });
1812 }
1813
1814 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1815 self.action_log
1816 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1817 }
1818
1819 pub fn reject_edits_in_range(
1820 &mut self,
1821 buffer: Entity<language::Buffer>,
1822 buffer_range: Range<language::Anchor>,
1823 cx: &mut Context<Self>,
1824 ) -> Task<Result<()>> {
1825 self.action_log.update(cx, |action_log, cx| {
1826 action_log.reject_edits_in_range(buffer, buffer_range, cx)
1827 })
1828 }
1829
1830 pub fn action_log(&self) -> &Entity<ActionLog> {
1831 &self.action_log
1832 }
1833
1834 pub fn project(&self) -> &Entity<Project> {
1835 &self.project
1836 }
1837
1838 pub fn cumulative_token_usage(&self) -> TokenUsage {
1839 self.cumulative_token_usage.clone()
1840 }
1841
1842 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1843 let model_registry = LanguageModelRegistry::read_global(cx);
1844 let Some(model) = model_registry.default_model() else {
1845 return TotalTokenUsage::default();
1846 };
1847
1848 let max = model.model.max_token_count();
1849
1850 #[cfg(debug_assertions)]
1851 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1852 .unwrap_or("0.8".to_string())
1853 .parse()
1854 .unwrap();
1855 #[cfg(not(debug_assertions))]
1856 let warning_threshold: f32 = 0.8;
1857
1858 let total = self.cumulative_token_usage.total_tokens() as usize;
1859
1860 let ratio = if total >= max {
1861 TokenUsageRatio::Exceeded
1862 } else if total as f32 / max as f32 >= warning_threshold {
1863 TokenUsageRatio::Warning
1864 } else {
1865 TokenUsageRatio::Normal
1866 };
1867
1868 TotalTokenUsage { total, max, ratio }
1869 }
1870
1871 pub fn deny_tool_use(
1872 &mut self,
1873 tool_use_id: LanguageModelToolUseId,
1874 tool_name: Arc<str>,
1875 cx: &mut Context<Self>,
1876 ) {
1877 let err = Err(anyhow::anyhow!(
1878 "Permission to run tool action denied by user"
1879 ));
1880
1881 self.tool_use
1882 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1883
1884 cx.emit(ThreadEvent::ToolFinished {
1885 tool_use_id,
1886 pending_tool_use: None,
1887 canceled: true,
1888 });
1889 }
1890}
1891
1892pub fn system_prompt_reminder(prompt_builder: &prompt_store::PromptBuilder) -> String {
1893 prompt_builder
1894 .generate_assistant_system_prompt_reminder()
1895 .unwrap_or_default()
1896}
1897
1898#[derive(Debug, Clone)]
1899pub enum ThreadError {
1900 PaymentRequired,
1901 MaxMonthlySpendReached,
1902 Message {
1903 header: SharedString,
1904 message: SharedString,
1905 },
1906}
1907
1908#[derive(Debug, Clone)]
1909pub enum ThreadEvent {
1910 ShowError(ThreadError),
1911 StreamedCompletion,
1912 StreamedAssistantText(MessageId, String),
1913 StreamedAssistantThinking(MessageId, String),
1914 DoneStreaming,
1915 MessageAdded(MessageId),
1916 MessageEdited(MessageId),
1917 MessageDeleted(MessageId),
1918 SummaryGenerated,
1919 SummaryChanged,
1920 UsePendingTools,
1921 ToolFinished {
1922 #[allow(unused)]
1923 tool_use_id: LanguageModelToolUseId,
1924 /// The pending tool use that corresponds to this tool.
1925 pending_tool_use: Option<PendingToolUse>,
1926 /// Whether the tool was canceled by the user.
1927 canceled: bool,
1928 },
1929 CheckpointChanged,
1930 ToolConfirmationNeeded,
1931}
1932
1933impl EventEmitter<ThreadEvent> for Thread {}
1934
1935struct PendingCompletion {
1936 id: usize,
1937 _task: Task<()>,
1938}
1939
1940#[cfg(test)]
1941mod tests {
1942 use super::*;
1943 use crate::{ThreadStore, context_store::ContextStore, thread_store};
1944 use assistant_settings::AssistantSettings;
1945 use context_server::ContextServerSettings;
1946 use editor::EditorSettings;
1947 use gpui::TestAppContext;
1948 use project::{FakeFs, Project};
1949 use prompt_store::PromptBuilder;
1950 use serde_json::json;
1951 use settings::{Settings, SettingsStore};
1952 use std::sync::Arc;
1953 use theme::ThemeSettings;
1954 use util::path;
1955 use workspace::Workspace;
1956
1957 #[gpui::test]
1958 async fn test_message_with_context(cx: &mut TestAppContext) {
1959 init_test_settings(cx);
1960
1961 let project = create_test_project(
1962 cx,
1963 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
1964 )
1965 .await;
1966
1967 let (_workspace, _thread_store, thread, context_store, prompt_builder) =
1968 setup_test_environment(cx, project.clone()).await;
1969
1970 add_file_to_context(&project, &context_store, "test/code.rs", cx)
1971 .await
1972 .unwrap();
1973
1974 let context =
1975 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
1976
1977 // Insert user message with context
1978 let message_id = thread.update(cx, |thread, cx| {
1979 thread.insert_user_message("Please explain this code", vec![context], None, cx)
1980 });
1981
1982 // Check content and context in message object
1983 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
1984
1985 // Use different path format strings based on platform for the test
1986 #[cfg(windows)]
1987 let path_part = r"test\code.rs";
1988 #[cfg(not(windows))]
1989 let path_part = "test/code.rs";
1990
1991 let expected_context = format!(
1992 r#"
1993<context>
1994The following items were attached by the user. You don't need to use other tools to read them.
1995
1996<files>
1997```rs {path_part}
1998fn main() {{
1999 println!("Hello, world!");
2000}}
2001```
2002</files>
2003</context>
2004"#
2005 );
2006
2007 assert_eq!(message.role, Role::User);
2008 assert_eq!(message.segments.len(), 1);
2009 assert_eq!(
2010 message.segments[0],
2011 MessageSegment::Text("Please explain this code".to_string())
2012 );
2013 assert_eq!(message.context, expected_context);
2014
2015 // Check message in request
2016 let request = thread.read_with(cx, |thread, cx| {
2017 thread.to_completion_request(RequestKind::Chat, cx)
2018 });
2019
2020 assert_eq!(request.messages.len(), 1);
2021 let actual_message = request.messages[0].string_contents();
2022 let expected_content = format!(
2023 "{}Please explain this code{}",
2024 expected_context,
2025 system_prompt_reminder(&prompt_builder)
2026 );
2027
2028 assert_eq!(actual_message, expected_content);
2029 }
2030
2031 #[gpui::test]
2032 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2033 init_test_settings(cx);
2034
2035 let project = create_test_project(
2036 cx,
2037 json!({
2038 "file1.rs": "fn function1() {}\n",
2039 "file2.rs": "fn function2() {}\n",
2040 "file3.rs": "fn function3() {}\n",
2041 }),
2042 )
2043 .await;
2044
2045 let (_, _thread_store, thread, context_store, _prompt_builder) =
2046 setup_test_environment(cx, project.clone()).await;
2047
2048 // Open files individually
2049 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2050 .await
2051 .unwrap();
2052 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2053 .await
2054 .unwrap();
2055 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2056 .await
2057 .unwrap();
2058
2059 // Get the context objects
2060 let contexts = context_store.update(cx, |store, _| store.context().clone());
2061 assert_eq!(contexts.len(), 3);
2062
2063 // First message with context 1
2064 let message1_id = thread.update(cx, |thread, cx| {
2065 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2066 });
2067
2068 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2069 let message2_id = thread.update(cx, |thread, cx| {
2070 thread.insert_user_message(
2071 "Message 2",
2072 vec![contexts[0].clone(), contexts[1].clone()],
2073 None,
2074 cx,
2075 )
2076 });
2077
2078 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2079 let message3_id = thread.update(cx, |thread, cx| {
2080 thread.insert_user_message(
2081 "Message 3",
2082 vec![
2083 contexts[0].clone(),
2084 contexts[1].clone(),
2085 contexts[2].clone(),
2086 ],
2087 None,
2088 cx,
2089 )
2090 });
2091
2092 // Check what contexts are included in each message
2093 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2094 (
2095 thread.message(message1_id).unwrap().clone(),
2096 thread.message(message2_id).unwrap().clone(),
2097 thread.message(message3_id).unwrap().clone(),
2098 )
2099 });
2100
2101 // First message should include context 1
2102 assert!(message1.context.contains("file1.rs"));
2103
2104 // Second message should include only context 2 (not 1)
2105 assert!(!message2.context.contains("file1.rs"));
2106 assert!(message2.context.contains("file2.rs"));
2107
2108 // Third message should include only context 3 (not 1 or 2)
2109 assert!(!message3.context.contains("file1.rs"));
2110 assert!(!message3.context.contains("file2.rs"));
2111 assert!(message3.context.contains("file3.rs"));
2112
2113 // Check entire request to make sure all contexts are properly included
2114 let request = thread.read_with(cx, |thread, cx| {
2115 thread.to_completion_request(RequestKind::Chat, cx)
2116 });
2117
2118 // The request should contain all 3 messages
2119 assert_eq!(request.messages.len(), 3);
2120
2121 // Check that the contexts are properly formatted in each message
2122 assert!(request.messages[0].string_contents().contains("file1.rs"));
2123 assert!(!request.messages[0].string_contents().contains("file2.rs"));
2124 assert!(!request.messages[0].string_contents().contains("file3.rs"));
2125
2126 assert!(!request.messages[1].string_contents().contains("file1.rs"));
2127 assert!(request.messages[1].string_contents().contains("file2.rs"));
2128 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2129
2130 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2131 assert!(!request.messages[2].string_contents().contains("file2.rs"));
2132 assert!(request.messages[2].string_contents().contains("file3.rs"));
2133 }
2134
2135 #[gpui::test]
2136 async fn test_message_without_files(cx: &mut TestAppContext) {
2137 init_test_settings(cx);
2138
2139 let project = create_test_project(
2140 cx,
2141 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2142 )
2143 .await;
2144
2145 let (_, _thread_store, thread, _context_store, prompt_builder) =
2146 setup_test_environment(cx, project.clone()).await;
2147
2148 // Insert user message without any context (empty context vector)
2149 let message_id = thread.update(cx, |thread, cx| {
2150 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2151 });
2152
2153 // Check content and context in message object
2154 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2155
2156 // Context should be empty when no files are included
2157 assert_eq!(message.role, Role::User);
2158 assert_eq!(message.segments.len(), 1);
2159 assert_eq!(
2160 message.segments[0],
2161 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2162 );
2163 assert_eq!(message.context, "");
2164
2165 // Check message in request
2166 let request = thread.read_with(cx, |thread, cx| {
2167 thread.to_completion_request(RequestKind::Chat, cx)
2168 });
2169
2170 assert_eq!(request.messages.len(), 1);
2171 let actual_message = request.messages[0].string_contents();
2172 let expected_content = format!(
2173 "What is the best way to learn Rust?{}",
2174 system_prompt_reminder(&prompt_builder)
2175 );
2176
2177 assert_eq!(actual_message, expected_content);
2178
2179 // Add second message, also without context
2180 let message2_id = thread.update(cx, |thread, cx| {
2181 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2182 });
2183
2184 let message2 =
2185 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2186 assert_eq!(message2.context, "");
2187
2188 // Check that both messages appear in the request
2189 let request = thread.read_with(cx, |thread, cx| {
2190 thread.to_completion_request(RequestKind::Chat, cx)
2191 });
2192
2193 assert_eq!(request.messages.len(), 2);
2194 // First message should be the system prompt
2195 assert_eq!(request.messages[0].role, Role::User);
2196
2197 // Second message should be the user message with prompt reminder
2198 let actual_message = request.messages[1].string_contents();
2199 let expected_content = format!(
2200 "Are there any good books?{}",
2201 system_prompt_reminder(&prompt_builder)
2202 );
2203
2204 assert_eq!(actual_message, expected_content);
2205 }
2206
2207 #[gpui::test]
2208 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2209 init_test_settings(cx);
2210
2211 let project = create_test_project(
2212 cx,
2213 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2214 )
2215 .await;
2216
2217 let (_workspace, _thread_store, thread, context_store, prompt_builder) =
2218 setup_test_environment(cx, project.clone()).await;
2219
2220 // Open buffer and add it to context
2221 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2222 .await
2223 .unwrap();
2224
2225 let context =
2226 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2227
2228 // Insert user message with the buffer as context
2229 thread.update(cx, |thread, cx| {
2230 thread.insert_user_message("Explain this code", vec![context], None, cx)
2231 });
2232
2233 // Create a request and check that it doesn't have a stale buffer warning yet
2234 let initial_request = thread.read_with(cx, |thread, cx| {
2235 thread.to_completion_request(RequestKind::Chat, cx)
2236 });
2237
2238 // Make sure we don't have a stale file warning yet
2239 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2240 msg.string_contents()
2241 .contains("These files changed since last read:")
2242 });
2243 assert!(
2244 !has_stale_warning,
2245 "Should not have stale buffer warning before buffer is modified"
2246 );
2247
2248 // Modify the buffer
2249 buffer.update(cx, |buffer, cx| {
2250 // Find a position at the end of line 1
2251 buffer.edit(
2252 [(1..1, "\n println!(\"Added a new line\");\n")],
2253 None,
2254 cx,
2255 );
2256 });
2257
2258 // Insert another user message without context
2259 thread.update(cx, |thread, cx| {
2260 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2261 });
2262
2263 // Create a new request and check for the stale buffer warning
2264 let new_request = thread.read_with(cx, |thread, cx| {
2265 thread.to_completion_request(RequestKind::Chat, cx)
2266 });
2267
2268 // We should have a stale file warning as the last message
2269 let last_message = new_request
2270 .messages
2271 .last()
2272 .expect("Request should have messages");
2273
2274 // The last message should be the stale buffer notification
2275 assert_eq!(last_message.role, Role::User);
2276
2277 let actual_message = last_message.string_contents();
2278 let expected_content = format!(
2279 "These files changed since last read:\n- code.rs\n{}",
2280 system_prompt_reminder(&prompt_builder)
2281 );
2282
2283 assert_eq!(
2284 actual_message, expected_content,
2285 "Last message should be exactly the stale buffer notification"
2286 );
2287 }
2288
2289 fn init_test_settings(cx: &mut TestAppContext) {
2290 cx.update(|cx| {
2291 let settings_store = SettingsStore::test(cx);
2292 cx.set_global(settings_store);
2293 language::init(cx);
2294 Project::init_settings(cx);
2295 AssistantSettings::register(cx);
2296 thread_store::init(cx);
2297 workspace::init_settings(cx);
2298 ThemeSettings::register(cx);
2299 ContextServerSettings::register(cx);
2300 EditorSettings::register(cx);
2301 });
2302 }
2303
2304 // Helper to create a test project with test files
2305 async fn create_test_project(
2306 cx: &mut TestAppContext,
2307 files: serde_json::Value,
2308 ) -> Entity<Project> {
2309 let fs = FakeFs::new(cx.executor());
2310 fs.insert_tree(path!("/test"), files).await;
2311 Project::test(fs, [path!("/test").as_ref()], cx).await
2312 }
2313
2314 async fn setup_test_environment(
2315 cx: &mut TestAppContext,
2316 project: Entity<Project>,
2317 ) -> (
2318 Entity<Workspace>,
2319 Entity<ThreadStore>,
2320 Entity<Thread>,
2321 Entity<ContextStore>,
2322 Arc<PromptBuilder>,
2323 ) {
2324 let (workspace, cx) =
2325 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2326
2327 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
2328
2329 let thread_store = cx.update(|_, cx| {
2330 ThreadStore::new(project.clone(), Arc::default(), prompt_builder.clone(), cx).unwrap()
2331 });
2332
2333 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2334 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2335
2336 (
2337 workspace,
2338 thread_store,
2339 thread,
2340 context_store,
2341 prompt_builder,
2342 )
2343 }
2344
2345 async fn add_file_to_context(
2346 project: &Entity<Project>,
2347 context_store: &Entity<ContextStore>,
2348 path: &str,
2349 cx: &mut TestAppContext,
2350 ) -> Result<Entity<language::Buffer>> {
2351 let buffer_path = project
2352 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2353 .unwrap();
2354
2355 let buffer = project
2356 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2357 .await
2358 .unwrap();
2359
2360 context_store
2361 .update(cx, |store, cx| {
2362 store.add_file_from_buffer(buffer.clone(), cx)
2363 })
2364 .await?;
2365
2366 Ok(buffer)
2367 }
2368}