1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
42
43#[derive(Debug, Clone, Copy)]
44pub enum RequestKind {
45 Chat,
46 /// Used when summarizing a thread.
47 Summarize,
48}
49
50#[derive(
51 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
52)]
53pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
73#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
74pub struct MessageId(pub(crate) usize);
75
76impl MessageId {
77 fn post_inc(&mut self) -> Self {
78 Self(post_inc(&mut self.0))
79 }
80}
81
82/// A message in a [`Thread`].
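    ///
    /// A minimal sketch of how segments accumulate while streaming (hypothetical
    /// values; consecutive chunks of the same kind are appended to the last
    /// segment instead of starting a new one):
    ///
    /// ```ignore
    /// let mut message = Message {
    ///     id: MessageId(0),
    ///     role: Role::Assistant,
    ///     segments: Vec::new(),
    ///     context: String::new(),
    /// };
    /// message.push_text("Hello");
    /// message.push_text(", world");
    /// message.push_thinking("considering a follow-up");
    /// assert_eq!(message.segments.len(), 2);
    /// assert_eq!(message.to_string(), "Hello, world<think>considering a follow-up</think>");
    /// ```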
83#[derive(Debug, Clone)]
84pub struct Message {
85 pub id: MessageId,
86 pub role: Role,
87 pub segments: Vec<MessageSegment>,
88 pub context: String,
89}
90
91impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    /// The model sometimes runs a tool while producing no text at all, or only the [`USING_TOOL_MARKER`] placeholder.
94 pub fn should_display_content(&self) -> bool {
95 self.segments.iter().all(|segment| segment.should_display())
96 }
97
98 pub fn push_thinking(&mut self, text: &str) {
99 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
100 segment.push_str(text);
101 } else {
102 self.segments
103 .push(MessageSegment::Thinking(text.to_string()));
104 }
105 }
106
107 pub fn push_text(&mut self, text: &str) {
108 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
109 segment.push_str(text);
110 } else {
111 self.segments.push(MessageSegment::Text(text.to_string()));
112 }
113 }
114
115 pub fn to_string(&self) -> String {
116 let mut result = String::new();
117
118 if !self.context.is_empty() {
119 result.push_str(&self.context);
120 }
121
122 for segment in &self.segments {
123 match segment {
124 MessageSegment::Text(text) => result.push_str(text),
125 MessageSegment::Thinking(text) => {
126 result.push_str("<think>");
127 result.push_str(text);
128 result.push_str("</think>");
129 }
130 }
131 }
132
133 result
134 }
135}
136
137#[derive(Debug, Clone, PartialEq, Eq)]
138pub enum MessageSegment {
139 Text(String),
140 Thinking(String),
141}
142
143impl MessageSegment {
144 pub fn text_mut(&mut self) -> &mut String {
145 match self {
146 Self::Text(text) => text,
147 Self::Thinking(text) => text,
148 }
149 }
150
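    /// Returns whether this segment has user-visible content: non-empty text
    /// that isn't just the [`USING_TOOL_MARKER`] placeholder.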
151 pub fn should_display(&self) -> bool {
152 // We add USING_TOOL_MARKER when making a request that includes tool uses
153 // without non-whitespace text around them, and this can cause the model
154 // to mimic the pattern, so we consider those segments not displayable.
155 match self {
            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::Thinking(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
158 }
159 }
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct ProjectSnapshot {
164 pub worktree_snapshots: Vec<WorktreeSnapshot>,
165 pub unsaved_buffer_paths: Vec<String>,
166 pub timestamp: DateTime<Utc>,
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct WorktreeSnapshot {
171 pub worktree_path: String,
172 pub git_state: Option<GitState>,
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct GitState {
177 pub remote_url: Option<String>,
178 pub head_sha: Option<String>,
179 pub current_branch: Option<String>,
180 pub diff: Option<String>,
181}
182
183#[derive(Clone)]
184pub struct ThreadCheckpoint {
185 message_id: MessageId,
186 git_checkpoint: GitStoreCheckpoint,
187}
188
189#[derive(Copy, Clone, Debug, PartialEq, Eq)]
190pub enum ThreadFeedback {
191 Positive,
192 Negative,
193}
194
195pub enum LastRestoreCheckpoint {
196 Pending {
197 message_id: MessageId,
198 },
199 Error {
200 message_id: MessageId,
201 error: String,
202 },
203}
204
205impl LastRestoreCheckpoint {
206 pub fn message_id(&self) -> MessageId {
207 match self {
208 LastRestoreCheckpoint::Pending { message_id } => *message_id,
209 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
210 }
211 }
212}
213
214#[derive(Clone, Debug, Default, Serialize, Deserialize)]
215pub enum DetailedSummaryState {
216 #[default]
217 NotGenerated,
218 Generating {
219 message_id: MessageId,
220 },
221 Generated {
222 text: SharedString,
223 message_id: MessageId,
224 },
225}
226
227#[derive(Default)]
228pub struct TotalTokenUsage {
229 pub total: usize,
230 pub max: usize,
231}
232
233impl TotalTokenUsage {
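    /// A rough sketch of the buckets `ratio` returns (the 0.8 warning cutoff
    /// can be overridden in debug builds via `ZED_THREAD_WARNING_THRESHOLD`):
    ///
    /// ```ignore
    /// let usage = TotalTokenUsage { total: 90, max: 100 };
    /// assert_eq!(usage.ratio(), TokenUsageRatio::Warning); // 0.9 >= 0.8
    ///
    /// let usage = TotalTokenUsage { total: 100, max: 100 };
    /// assert_eq!(usage.ratio(), TokenUsageRatio::Exceeded); // total >= max
    /// ```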
234 pub fn ratio(&self) -> TokenUsageRatio {
235 #[cfg(debug_assertions)]
236 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
237 .unwrap_or("0.8".to_string())
238 .parse()
239 .unwrap();
240 #[cfg(not(debug_assertions))]
241 let warning_threshold: f32 = 0.8;
242
243 if self.total >= self.max {
244 TokenUsageRatio::Exceeded
245 } else if self.total as f32 / self.max as f32 >= warning_threshold {
246 TokenUsageRatio::Warning
247 } else {
248 TokenUsageRatio::Normal
249 }
250 }
251
252 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
253 TotalTokenUsage {
254 total: self.total + tokens,
255 max: self.max,
256 }
257 }
258}
259
260#[derive(Debug, Default, PartialEq, Eq)]
261pub enum TokenUsageRatio {
262 #[default]
263 Normal,
264 Warning,
265 Exceeded,
266}
267
268/// A thread of conversation with the LLM.
269pub struct Thread {
270 id: ThreadId,
271 updated_at: DateTime<Utc>,
272 summary: Option<SharedString>,
273 pending_summary: Task<Option<()>>,
274 detailed_summary_state: DetailedSummaryState,
275 messages: Vec<Message>,
276 next_message_id: MessageId,
277 context: BTreeMap<ContextId, AssistantContext>,
278 context_by_message: HashMap<MessageId, Vec<ContextId>>,
279 project_context: SharedProjectContext,
280 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
281 completion_count: usize,
282 pending_completions: Vec<PendingCompletion>,
283 project: Entity<Project>,
284 prompt_builder: Arc<PromptBuilder>,
285 tools: Entity<ToolWorkingSet>,
286 tool_use: ToolUseState,
287 action_log: Entity<ActionLog>,
288 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
289 pending_checkpoint: Option<ThreadCheckpoint>,
290 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
291 request_token_usage: Vec<TokenUsage>,
292 cumulative_token_usage: TokenUsage,
293 exceeded_window_error: Option<ExceededWindowError>,
294 feedback: Option<ThreadFeedback>,
295 message_feedback: HashMap<MessageId, ThreadFeedback>,
296 last_auto_capture_at: Option<Instant>,
297}
298
299#[derive(Debug, Clone, Serialize, Deserialize)]
300pub struct ExceededWindowError {
    /// The model in use when the last message exceeded the context window.
    model_id: LanguageModelId,
    /// Token count including the last message.
    token_count: usize,
305}
306
307impl Thread {
308 pub fn new(
309 project: Entity<Project>,
310 tools: Entity<ToolWorkingSet>,
311 prompt_builder: Arc<PromptBuilder>,
312 system_prompt: SharedProjectContext,
313 cx: &mut Context<Self>,
314 ) -> Self {
315 Self {
316 id: ThreadId::new(),
317 updated_at: Utc::now(),
318 summary: None,
319 pending_summary: Task::ready(None),
320 detailed_summary_state: DetailedSummaryState::NotGenerated,
321 messages: Vec::new(),
322 next_message_id: MessageId(0),
323 context: BTreeMap::default(),
324 context_by_message: HashMap::default(),
325 project_context: system_prompt,
326 checkpoints_by_message: HashMap::default(),
327 completion_count: 0,
328 pending_completions: Vec::new(),
329 project: project.clone(),
330 prompt_builder,
331 tools: tools.clone(),
332 last_restore_checkpoint: None,
333 pending_checkpoint: None,
334 tool_use: ToolUseState::new(tools.clone()),
335 action_log: cx.new(|_| ActionLog::new(project.clone())),
336 initial_project_snapshot: {
337 let project_snapshot = Self::project_snapshot(project, cx);
338 cx.foreground_executor()
339 .spawn(async move { Some(project_snapshot.await) })
340 .shared()
341 },
342 request_token_usage: Vec::new(),
343 cumulative_token_usage: TokenUsage::default(),
344 exceeded_window_error: None,
345 feedback: None,
346 message_feedback: HashMap::default(),
347 last_auto_capture_at: None,
348 }
349 }
350
351 pub fn deserialize(
352 id: ThreadId,
353 serialized: SerializedThread,
354 project: Entity<Project>,
355 tools: Entity<ToolWorkingSet>,
356 prompt_builder: Arc<PromptBuilder>,
357 project_context: SharedProjectContext,
358 cx: &mut Context<Self>,
359 ) -> Self {
360 let next_message_id = MessageId(
361 serialized
362 .messages
363 .last()
364 .map(|message| message.id.0 + 1)
365 .unwrap_or(0),
366 );
367 let tool_use =
368 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
369
370 Self {
371 id,
372 updated_at: serialized.updated_at,
373 summary: Some(serialized.summary),
374 pending_summary: Task::ready(None),
375 detailed_summary_state: serialized.detailed_summary_state,
376 messages: serialized
377 .messages
378 .into_iter()
379 .map(|message| Message {
380 id: message.id,
381 role: message.role,
382 segments: message
383 .segments
384 .into_iter()
385 .map(|segment| match segment {
386 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
387 SerializedMessageSegment::Thinking { text } => {
388 MessageSegment::Thinking(text)
389 }
390 })
391 .collect(),
392 context: message.context,
393 })
394 .collect(),
395 next_message_id,
396 context: BTreeMap::default(),
397 context_by_message: HashMap::default(),
398 project_context,
399 checkpoints_by_message: HashMap::default(),
400 completion_count: 0,
401 pending_completions: Vec::new(),
402 last_restore_checkpoint: None,
403 pending_checkpoint: None,
404 project: project.clone(),
405 prompt_builder,
406 tools,
407 tool_use,
408 action_log: cx.new(|_| ActionLog::new(project)),
409 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
410 request_token_usage: serialized.request_token_usage,
411 cumulative_token_usage: serialized.cumulative_token_usage,
412 exceeded_window_error: None,
413 feedback: None,
414 message_feedback: HashMap::default(),
415 last_auto_capture_at: None,
416 }
417 }
418
419 pub fn id(&self) -> &ThreadId {
420 &self.id
421 }
422
423 pub fn is_empty(&self) -> bool {
424 self.messages.is_empty()
425 }
426
427 pub fn updated_at(&self) -> DateTime<Utc> {
428 self.updated_at
429 }
430
431 pub fn touch_updated_at(&mut self) {
432 self.updated_at = Utc::now();
433 }
434
435 pub fn summary(&self) -> Option<SharedString> {
436 self.summary.clone()
437 }
438
439 pub fn project_context(&self) -> SharedProjectContext {
440 self.project_context.clone()
441 }
442
443 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
444
445 pub fn summary_or_default(&self) -> SharedString {
446 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
447 }
448
449 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
450 let Some(current_summary) = &self.summary else {
451 // Don't allow setting summary until generated
452 return;
453 };
454
455 let mut new_summary = new_summary.into();
456
457 if new_summary.is_empty() {
458 new_summary = Self::DEFAULT_SUMMARY;
459 }
460
461 if current_summary != &new_summary {
462 self.summary = Some(new_summary);
463 cx.emit(ThreadEvent::SummaryChanged);
464 }
465 }
466
467 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
468 self.latest_detailed_summary()
469 .unwrap_or_else(|| self.text().into())
470 }
471
472 fn latest_detailed_summary(&self) -> Option<SharedString> {
473 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
474 Some(text.clone())
475 } else {
476 None
477 }
478 }
479
480 pub fn message(&self, id: MessageId) -> Option<&Message> {
481 self.messages.iter().find(|message| message.id == id)
482 }
483
484 pub fn messages(&self) -> impl Iterator<Item = &Message> {
485 self.messages.iter()
486 }
487
488 pub fn is_generating(&self) -> bool {
489 !self.pending_completions.is_empty() || !self.all_tools_finished()
490 }
491
492 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
493 &self.tools
494 }
495
496 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
497 self.tool_use
498 .pending_tool_uses()
499 .into_iter()
500 .find(|tool_use| &tool_use.id == id)
501 }
502
503 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
504 self.tool_use
505 .pending_tool_uses()
506 .into_iter()
507 .filter(|tool_use| tool_use.status.needs_confirmation())
508 }
509
510 pub fn has_pending_tool_uses(&self) -> bool {
511 !self.tool_use.pending_tool_uses().is_empty()
512 }
513
514 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
515 self.checkpoints_by_message.get(&id).cloned()
516 }
517
518 pub fn restore_checkpoint(
519 &mut self,
520 checkpoint: ThreadCheckpoint,
521 cx: &mut Context<Self>,
522 ) -> Task<Result<()>> {
523 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
524 message_id: checkpoint.message_id,
525 });
526 cx.emit(ThreadEvent::CheckpointChanged);
527 cx.notify();
528
529 let git_store = self.project().read(cx).git_store().clone();
530 let restore = git_store.update(cx, |git_store, cx| {
531 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
532 });
533
534 cx.spawn(async move |this, cx| {
535 let result = restore.await;
536 this.update(cx, |this, cx| {
537 if let Err(err) = result.as_ref() {
538 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
539 message_id: checkpoint.message_id,
540 error: err.to_string(),
541 });
542 } else {
543 this.truncate(checkpoint.message_id, cx);
544 this.last_restore_checkpoint = None;
545 }
546 this.pending_checkpoint = None;
547 cx.emit(ThreadEvent::CheckpointChanged);
548 cx.notify();
549 })?;
550 result
551 })
552 }
553
554 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
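        // A checkpoint is only worth keeping if the project changed while the
        // completion ran: below we compare the pending checkpoint against a
        // fresh one and drop it when the two are identical.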
555 let pending_checkpoint = if self.is_generating() {
556 return;
557 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
558 checkpoint
559 } else {
560 return;
561 };
562
563 let git_store = self.project.read(cx).git_store().clone();
564 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
565 cx.spawn(async move |this, cx| match final_checkpoint.await {
566 Ok(final_checkpoint) => {
567 let equal = git_store
568 .update(cx, |store, cx| {
569 store.compare_checkpoints(
570 pending_checkpoint.git_checkpoint.clone(),
571 final_checkpoint.clone(),
572 cx,
573 )
574 })?
575 .await
576 .unwrap_or(false);
577
578 if equal {
579 git_store
580 .update(cx, |store, cx| {
581 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
582 })?
583 .detach();
584 } else {
585 this.update(cx, |this, cx| {
586 this.insert_checkpoint(pending_checkpoint, cx)
587 })?;
588 }
589
590 git_store
591 .update(cx, |store, cx| {
592 store.delete_checkpoint(final_checkpoint, cx)
593 })?
594 .detach();
595
596 Ok(())
597 }
598 Err(_) => this.update(cx, |this, cx| {
599 this.insert_checkpoint(pending_checkpoint, cx)
600 }),
601 })
602 .detach();
603 }
604
605 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
606 self.checkpoints_by_message
607 .insert(checkpoint.message_id, checkpoint);
608 cx.emit(ThreadEvent::CheckpointChanged);
609 cx.notify();
610 }
611
612 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
613 self.last_restore_checkpoint.as_ref()
614 }
615
616 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
617 let Some(message_ix) = self
618 .messages
619 .iter()
620 .rposition(|message| message.id == message_id)
621 else {
622 return;
623 };
624 for deleted_message in self.messages.drain(message_ix..) {
625 self.context_by_message.remove(&deleted_message.id);
626 self.checkpoints_by_message.remove(&deleted_message.id);
627 }
628 cx.notify();
629 }
630
631 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
632 self.context_by_message
633 .get(&id)
634 .into_iter()
635 .flat_map(|context| {
636 context
637 .iter()
                    .filter_map(|context_id| self.context.get(context_id))
639 })
640 }
641
642 /// Returns whether all of the tool uses have finished running.
643 pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // we've finished running all of the pending tools.
646 self.tool_use
647 .pending_tool_uses()
648 .iter()
649 .all(|tool_use| tool_use.status.is_error())
650 }
651
652 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
653 self.tool_use.tool_uses_for_message(id, cx)
654 }
655
656 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
657 self.tool_use.tool_results_for_message(id)
658 }
659
660 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
661 self.tool_use.tool_result(id)
662 }
663
664 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
665 Some(&self.tool_use.tool_result(id)?.content)
666 }
667
668 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
669 self.tool_use.tool_result_card(id).cloned()
670 }
671
672 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
673 self.tool_use.message_has_tool_results(message_id)
674 }
675
676 /// Filter out contexts that have already been included in previous messages
677 pub fn filter_new_context<'a>(
678 &self,
679 context: impl Iterator<Item = &'a AssistantContext>,
680 ) -> impl Iterator<Item = &'a AssistantContext> {
681 context.filter(|ctx| self.is_context_new(ctx))
682 }
683
684 fn is_context_new(&self, context: &AssistantContext) -> bool {
685 !self.context.contains_key(&context.id())
686 }
687
688 pub fn insert_user_message(
689 &mut self,
690 text: impl Into<String>,
691 context: Vec<AssistantContext>,
692 git_checkpoint: Option<GitStoreCheckpoint>,
693 cx: &mut Context<Self>,
694 ) -> MessageId {
695 let text = text.into();
696
697 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
698
699 let new_context: Vec<_> = context
700 .into_iter()
701 .filter(|ctx| self.is_context_new(ctx))
702 .collect();
703
704 if !new_context.is_empty() {
705 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
706 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
707 message.context = context_string;
708 }
709 }
710
711 self.action_log.update(cx, |log, cx| {
712 // Track all buffers added as context
713 for ctx in &new_context {
714 match ctx {
715 AssistantContext::File(file_ctx) => {
716 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
717 }
718 AssistantContext::Directory(dir_ctx) => {
719 for context_buffer in &dir_ctx.context_buffers {
720 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
721 }
722 }
723 AssistantContext::Symbol(symbol_ctx) => {
724 log.buffer_added_as_context(
725 symbol_ctx.context_symbol.buffer.clone(),
726 cx,
727 );
728 }
729 AssistantContext::Excerpt(excerpt_context) => {
730 log.buffer_added_as_context(
731 excerpt_context.context_buffer.buffer.clone(),
732 cx,
733 );
734 }
735 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
736 }
737 }
738 });
739 }
740
741 let context_ids = new_context
742 .iter()
743 .map(|context| context.id())
744 .collect::<Vec<_>>();
745 self.context.extend(
746 new_context
747 .into_iter()
748 .map(|context| (context.id(), context)),
749 );
750 self.context_by_message.insert(message_id, context_ids);
751
752 if let Some(git_checkpoint) = git_checkpoint {
753 self.pending_checkpoint = Some(ThreadCheckpoint {
754 message_id,
755 git_checkpoint,
756 });
757 }
758
759 self.auto_capture_telemetry(cx);
760
761 message_id
762 }
763
764 pub fn insert_message(
765 &mut self,
766 role: Role,
767 segments: Vec<MessageSegment>,
768 cx: &mut Context<Self>,
769 ) -> MessageId {
770 let id = self.next_message_id.post_inc();
771 self.messages.push(Message {
772 id,
773 role,
774 segments,
775 context: String::new(),
776 });
777 self.touch_updated_at();
778 cx.emit(ThreadEvent::MessageAdded(id));
779 id
780 }
781
782 pub fn edit_message(
783 &mut self,
784 id: MessageId,
785 new_role: Role,
786 new_segments: Vec<MessageSegment>,
787 cx: &mut Context<Self>,
788 ) -> bool {
789 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
790 return false;
791 };
792 message.role = new_role;
793 message.segments = new_segments;
794 self.touch_updated_at();
795 cx.emit(ThreadEvent::MessageEdited(id));
796 true
797 }
798
799 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
800 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
801 return false;
802 };
803 self.messages.remove(index);
804 self.context_by_message.remove(&id);
805 self.touch_updated_at();
806 cx.emit(ThreadEvent::MessageDeleted(id));
807 true
808 }
809
810 /// Returns the representation of this [`Thread`] in a textual form.
811 ///
812 /// This is the representation we use when attaching a thread as context to another thread.
813 pub fn text(&self) -> String {
814 let mut text = String::new();
815
816 for message in &self.messages {
817 text.push_str(match message.role {
818 language_model::Role::User => "User:",
819 language_model::Role::Assistant => "Assistant:",
820 language_model::Role::System => "System:",
821 });
822 text.push('\n');
823
824 for segment in &message.segments {
825 match segment {
826 MessageSegment::Text(content) => text.push_str(content),
827 MessageSegment::Thinking(content) => {
828 text.push_str(&format!("<think>{}</think>", content))
829 }
830 }
831 }
832 text.push('\n');
833 }
834
835 text
836 }
837
838 /// Serializes this thread into a format for storage or telemetry.
839 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
840 let initial_project_snapshot = self.initial_project_snapshot.clone();
841 cx.spawn(async move |this, cx| {
842 let initial_project_snapshot = initial_project_snapshot.await;
843 this.read_with(cx, |this, cx| SerializedThread {
844 version: SerializedThread::VERSION.to_string(),
845 summary: this.summary_or_default(),
846 updated_at: this.updated_at(),
847 messages: this
848 .messages()
849 .map(|message| SerializedMessage {
850 id: message.id,
851 role: message.role,
852 segments: message
853 .segments
854 .iter()
855 .map(|segment| match segment {
856 MessageSegment::Text(text) => {
857 SerializedMessageSegment::Text { text: text.clone() }
858 }
859 MessageSegment::Thinking(text) => {
860 SerializedMessageSegment::Thinking { text: text.clone() }
861 }
862 })
863 .collect(),
864 tool_uses: this
865 .tool_uses_for_message(message.id, cx)
866 .into_iter()
867 .map(|tool_use| SerializedToolUse {
868 id: tool_use.id,
869 name: tool_use.name,
870 input: tool_use.input,
871 })
872 .collect(),
873 tool_results: this
874 .tool_results_for_message(message.id)
875 .into_iter()
876 .map(|tool_result| SerializedToolResult {
877 tool_use_id: tool_result.tool_use_id.clone(),
878 is_error: tool_result.is_error,
879 content: tool_result.content.clone(),
880 })
881 .collect(),
882 context: message.context.clone(),
883 })
884 .collect(),
885 initial_project_snapshot,
886 cumulative_token_usage: this.cumulative_token_usage,
887 request_token_usage: this.request_token_usage.clone(),
888 detailed_summary_state: this.detailed_summary_state.clone(),
889 exceeded_window_error: this.exceeded_window_error.clone(),
890 })
891 })
892 }
893
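    /// Sends the current thread to `model`, attaching every enabled tool whose
    /// input schema the model's tool input format can represent.
    ///
    /// A hedged usage sketch (assumes a default model is configured):
    ///
    /// ```ignore
    /// if let Some(ConfiguredModel { model, .. }) =
    ///     LanguageModelRegistry::read_global(cx).default_model()
    /// {
    ///     thread.insert_user_message("Explain this project", Vec::new(), None, cx);
    ///     thread.send_to_model(model, RequestKind::Chat, cx);
    /// }
    /// ```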
894 pub fn send_to_model(
895 &mut self,
896 model: Arc<dyn LanguageModel>,
897 request_kind: RequestKind,
898 cx: &mut Context<Self>,
899 ) {
900 let mut request = self.to_completion_request(request_kind, cx);
901 if model.supports_tools() {
902 request.tools = {
903 let mut tools = Vec::new();
904 tools.extend(
905 self.tools()
906 .read(cx)
907 .enabled_tools(cx)
908 .into_iter()
909 .filter_map(|tool| {
                        // Skip tools whose input schema can't be generated for this model's tool input format.
911 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
912 Some(LanguageModelRequestTool {
913 name: tool.name(),
914 description: tool.description(),
915 input_schema,
916 })
917 }),
918 );
919
920 tools
921 };
922 }
923
924 self.stream_completion(request, model, cx);
925 }
926
927 pub fn used_tools_since_last_user_message(&self) -> bool {
928 for message in self.messages.iter().rev() {
929 if self.tool_use.message_has_tool_results(message.id) {
930 return true;
931 } else if message.role == Role::User {
932 return false;
933 }
934 }
935
936 false
937 }
938
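    /// Builds the [`LanguageModelRequest`] for this thread.
    ///
    /// Roughly (a sketch, assuming the project context has been set): a cached
    /// system-prompt message comes first, then one request message per thread
    /// message (with tool results and tool uses attached for [`RequestKind::Chat`]),
    /// the last message is marked for prompt caching, and a trailing user message
    /// with stale-file/diagnostics reminders is appended when needed.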
939 pub fn to_completion_request(
940 &self,
941 request_kind: RequestKind,
942 cx: &App,
943 ) -> LanguageModelRequest {
944 let mut request = LanguageModelRequest {
945 messages: vec![],
946 tools: Vec::new(),
947 stop: Vec::new(),
948 temperature: None,
949 };
950
951 if let Some(project_context) = self.project_context.borrow().as_ref() {
952 if let Some(system_prompt) = self
953 .prompt_builder
954 .generate_assistant_system_prompt(project_context)
955 .context("failed to generate assistant system prompt")
956 .log_err()
957 {
958 request.messages.push(LanguageModelRequestMessage {
959 role: Role::System,
960 content: vec![MessageContent::Text(system_prompt)],
961 cache: true,
962 });
963 }
964 } else {
965 log::error!("project_context not set.")
966 }
967
968 for message in &self.messages {
969 let mut request_message = LanguageModelRequestMessage {
970 role: message.role,
971 content: Vec::new(),
972 cache: false,
973 };
974
975 match request_kind {
976 RequestKind::Chat => {
977 self.tool_use
978 .attach_tool_results(message.id, &mut request_message);
979 }
980 RequestKind::Summarize => {
981 // We don't care about tool use during summarization.
982 if self.tool_use.message_has_tool_results(message.id) {
983 continue;
984 }
985 }
986 }
987
988 if !message.segments.is_empty() {
989 request_message
990 .content
991 .push(MessageContent::Text(message.to_string()));
992 }
993
994 match request_kind {
995 RequestKind::Chat => {
996 self.tool_use
997 .attach_tool_uses(message.id, &mut request_message);
998 }
999 RequestKind::Summarize => {
1000 // We don't care about tool use during summarization.
1001 }
1002 };
1003
1004 request.messages.push(request_message);
1005 }
1006
1007 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1008 if let Some(last) = request.messages.last_mut() {
1009 last.cache = true;
1010 }
1011
1012 self.attached_tracked_files_state(&mut request.messages, cx);
1013
1014 request
1015 }
1016
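    /// Appends a synthetic user message listing files that changed since the
    /// model last read them, plus a reminder to re-check project diagnostics
    /// after edits, so the model can refresh its view before acting further.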
1017 fn attached_tracked_files_state(
1018 &self,
1019 messages: &mut Vec<LanguageModelRequestMessage>,
1020 cx: &App,
1021 ) {
1022 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1023
1024 let mut stale_message = String::new();
1025
1026 let action_log = self.action_log.read(cx);
1027
1028 for stale_file in action_log.stale_buffers(cx) {
1029 let Some(file) = stale_file.read(cx).file() else {
1030 continue;
1031 };
1032
1033 if stale_message.is_empty() {
            writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
1035 }
1036
1037 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1038 }
1039
1040 let mut content = Vec::with_capacity(2);
1041
1042 if !stale_message.is_empty() {
1043 content.push(stale_message.into());
1044 }
1045
1046 if action_log.has_edited_files_since_project_diagnostics_check() {
1047 content.push(
1048 "\n\nWhen you're done making changes, make sure to check project diagnostics \
1049 and fix all errors AND warnings you introduced! \
1050 DO NOT mention you're going to do this until you're done."
1051 .into(),
1052 );
1053 }
1054
1055 if !content.is_empty() {
1056 let context_message = LanguageModelRequestMessage {
1057 role: Role::User,
1058 content,
1059 cache: false,
1060 };
1061
1062 messages.push(context_message);
1063 }
1064 }
1065
1066 pub fn stream_completion(
1067 &mut self,
1068 request: LanguageModelRequest,
1069 model: Arc<dyn LanguageModel>,
1070 cx: &mut Context<Self>,
1071 ) {
1072 let pending_completion_id = post_inc(&mut self.completion_count);
1073 let task = cx.spawn(async move |thread, cx| {
1074 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1075 let initial_token_usage =
1076 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1077 let stream_completion = async {
1078 let (mut events, usage) = stream_completion_future.await?;
1079 let mut stop_reason = StopReason::EndTurn;
1080 let mut current_token_usage = TokenUsage::default();
1081
1082 if let Some(usage) = usage {
1083 thread
1084 .update(cx, |_thread, cx| {
1085 cx.emit(ThreadEvent::UsageUpdated(usage));
1086 })
1087 .ok();
1088 }
1089
1090 while let Some(event) = events.next().await {
1091 let event = event?;
1092
1093 thread.update(cx, |thread, cx| {
1094 match event {
1095 LanguageModelCompletionEvent::StartMessage { .. } => {
1096 thread.insert_message(
1097 Role::Assistant,
1098 vec![MessageSegment::Text(String::new())],
1099 cx,
1100 );
1101 }
1102 LanguageModelCompletionEvent::Stop(reason) => {
1103 stop_reason = reason;
1104 }
1105 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1106 thread.update_token_usage_at_last_message(token_usage);
1107 thread.cumulative_token_usage = thread.cumulative_token_usage
1108 + token_usage
1109 - current_token_usage;
1110 current_token_usage = token_usage;
1111 }
1112 LanguageModelCompletionEvent::Text(chunk) => {
1113 if let Some(last_message) = thread.messages.last_mut() {
1114 if last_message.role == Role::Assistant {
1115 last_message.push_text(&chunk);
1116 cx.emit(ThreadEvent::StreamedAssistantText(
1117 last_message.id,
1118 chunk,
1119 ));
1120 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: we do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // would duplicate the chunk's text in the rendered Markdown.
1126 thread.insert_message(
1127 Role::Assistant,
1128 vec![MessageSegment::Text(chunk.to_string())],
1129 cx,
1130 );
1131 };
1132 }
1133 }
1134 LanguageModelCompletionEvent::Thinking(chunk) => {
1135 if let Some(last_message) = thread.messages.last_mut() {
1136 if last_message.role == Role::Assistant {
1137 last_message.push_thinking(&chunk);
1138 cx.emit(ThreadEvent::StreamedAssistantThinking(
1139 last_message.id,
1140 chunk,
1141 ));
1142 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: we do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // would duplicate the chunk's text in the rendered Markdown.
1148 thread.insert_message(
1149 Role::Assistant,
1150 vec![MessageSegment::Thinking(chunk.to_string())],
1151 cx,
1152 );
1153 };
1154 }
1155 }
1156 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1157 let last_assistant_message_id = thread
1158 .messages
1159 .iter_mut()
1160 .rfind(|message| message.role == Role::Assistant)
1161 .map(|message| message.id)
1162 .unwrap_or_else(|| {
1163 thread.insert_message(Role::Assistant, vec![], cx)
1164 });
1165
1166 thread.tool_use.request_tool_use(
1167 last_assistant_message_id,
1168 tool_use,
1169 cx,
1170 );
1171 }
1172 }
1173
1174 thread.touch_updated_at();
1175 cx.emit(ThreadEvent::StreamedCompletion);
1176 cx.notify();
1177
1178 thread.auto_capture_telemetry(cx);
1179 })?;
1180
1181 smol::future::yield_now().await;
1182 }
1183
1184 thread.update(cx, |thread, cx| {
1185 thread
1186 .pending_completions
1187 .retain(|completion| completion.id != pending_completion_id);
1188
1189 if thread.summary.is_none() && thread.messages.len() >= 2 {
1190 thread.summarize(cx);
1191 }
1192 })?;
1193
1194 anyhow::Ok(stop_reason)
1195 };
1196
1197 let result = stream_completion.await;
1198
1199 thread
1200 .update(cx, |thread, cx| {
1201 thread.finalize_pending_checkpoint(cx);
1202 match result.as_ref() {
1203 Ok(stop_reason) => match stop_reason {
1204 StopReason::ToolUse => {
1205 let tool_uses = thread.use_pending_tools(cx);
1206 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1207 }
1208 StopReason::EndTurn => {}
1209 StopReason::MaxTokens => {}
1210 },
1211 Err(error) => {
1212 if error.is::<PaymentRequiredError>() {
1213 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1214 } else if error.is::<MaxMonthlySpendReachedError>() {
1215 cx.emit(ThreadEvent::ShowError(
1216 ThreadError::MaxMonthlySpendReached,
1217 ));
1218 } else if let Some(error) =
1219 error.downcast_ref::<ModelRequestLimitReachedError>()
1220 {
1221 cx.emit(ThreadEvent::ShowError(
1222 ThreadError::ModelRequestLimitReached { plan: error.plan },
1223 ));
1224 } else if let Some(known_error) =
1225 error.downcast_ref::<LanguageModelKnownError>()
1226 {
1227 match known_error {
1228 LanguageModelKnownError::ContextWindowLimitExceeded {
1229 tokens,
1230 } => {
1231 thread.exceeded_window_error = Some(ExceededWindowError {
1232 model_id: model.id(),
1233 token_count: *tokens,
1234 });
1235 cx.notify();
1236 }
1237 }
1238 } else {
1239 let error_message = error
1240 .chain()
1241 .map(|err| err.to_string())
1242 .collect::<Vec<_>>()
1243 .join("\n");
1244 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1245 header: "Error interacting with language model".into(),
1246 message: SharedString::from(error_message.clone()),
1247 }));
1248 }
1249
1250 thread.cancel_last_completion(cx);
1251 }
1252 }
1253 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1254
1255 thread.auto_capture_telemetry(cx);
1256
1257 if let Ok(initial_usage) = initial_token_usage {
1258 let usage = thread.cumulative_token_usage - initial_usage;
1259
1260 telemetry::event!(
1261 "Assistant Thread Completion",
1262 thread_id = thread.id().to_string(),
1263 model = model.telemetry_id(),
1264 model_provider = model.provider_id().to_string(),
1265 input_tokens = usage.input_tokens,
1266 output_tokens = usage.output_tokens,
1267 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1268 cache_read_input_tokens = usage.cache_read_input_tokens,
1269 );
1270 }
1271 })
1272 .ok();
1273 });
1274
1275 self.pending_completions.push(PendingCompletion {
1276 id: pending_completion_id,
1277 _task: task,
1278 });
1279 }
1280
1281 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1282 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1283 return;
1284 };
1285
1286 if !model.provider.is_authenticated(cx) {
1287 return;
1288 }
1289
1290 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1291 request.messages.push(LanguageModelRequestMessage {
1292 role: Role::User,
1293 content: vec![
1294 "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1295 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1296 If the conversation is about a specific subject, include it in the title. \
1297 Be descriptive. DO NOT speak in the first person."
1298 .into(),
1299 ],
1300 cache: false,
1301 });
1302
1303 self.pending_summary = cx.spawn(async move |this, cx| {
1304 async move {
1305 let stream = model.model.stream_completion_text(request, &cx);
1306 let mut messages = stream.await?;
1307
1308 let mut new_summary = String::new();
1309 while let Some(message) = messages.stream.next().await {
1310 let text = message?;
1311 let mut lines = text.lines();
1312 new_summary.extend(lines.next());
1313
1314 // Stop if the LLM generated multiple lines.
1315 if lines.next().is_some() {
1316 break;
1317 }
1318 }
1319
1320 this.update(cx, |this, cx| {
1321 if !new_summary.is_empty() {
1322 this.summary = Some(new_summary.into());
1323 }
1324
1325 cx.emit(ThreadEvent::SummaryGenerated);
1326 })?;
1327
1328 anyhow::Ok(())
1329 }
1330 .log_err()
1331 .await
1332 });
1333 }
1334
1335 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1336 let last_message_id = self.messages.last().map(|message| message.id)?;
1337
1338 match &self.detailed_summary_state {
1339 DetailedSummaryState::Generating { message_id, .. }
1340 | DetailedSummaryState::Generated { message_id, .. }
1341 if *message_id == last_message_id =>
1342 {
1343 // Already up-to-date
1344 return None;
1345 }
1346 _ => {}
1347 }
1348
1349 let ConfiguredModel { model, provider } =
1350 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1351
1352 if !provider.is_authenticated(cx) {
1353 return None;
1354 }
1355
1356 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1357
1358 request.messages.push(LanguageModelRequestMessage {
1359 role: Role::User,
1360 content: vec![
1361 "Generate a detailed summary of this conversation. Include:\n\
1362 1. A brief overview of what was discussed\n\
1363 2. Key facts or information discovered\n\
1364 3. Outcomes or conclusions reached\n\
1365 4. Any action items or next steps if any\n\
1366 Format it in Markdown with headings and bullet points."
1367 .into(),
1368 ],
1369 cache: false,
1370 });
1371
1372 let task = cx.spawn(async move |thread, cx| {
1373 let stream = model.stream_completion_text(request, &cx);
1374 let Some(mut messages) = stream.await.log_err() else {
1375 thread
1376 .update(cx, |this, _cx| {
1377 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1378 })
1379 .log_err();
1380
1381 return;
1382 };
1383
1384 let mut new_detailed_summary = String::new();
1385
1386 while let Some(chunk) = messages.stream.next().await {
1387 if let Some(chunk) = chunk.log_err() {
1388 new_detailed_summary.push_str(&chunk);
1389 }
1390 }
1391
1392 thread
1393 .update(cx, |this, _cx| {
1394 this.detailed_summary_state = DetailedSummaryState::Generated {
1395 text: new_detailed_summary.into(),
1396 message_id: last_message_id,
1397 };
1398 })
1399 .log_err();
1400 });
1401
1402 self.detailed_summary_state = DetailedSummaryState::Generating {
1403 message_id: last_message_id,
1404 };
1405
1406 Some(task)
1407 }
1408
1409 pub fn is_generating_detailed_summary(&self) -> bool {
1410 matches!(
1411 self.detailed_summary_state,
1412 DetailedSummaryState::Generating { .. }
1413 )
1414 }
1415
1416 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1417 self.auto_capture_telemetry(cx);
1418 let request = self.to_completion_request(RequestKind::Chat, cx);
1419 let messages = Arc::new(request.messages);
1420 let pending_tool_uses = self
1421 .tool_use
1422 .pending_tool_uses()
1423 .into_iter()
1424 .filter(|tool_use| tool_use.status.is_idle())
1425 .cloned()
1426 .collect::<Vec<_>>();
1427
1428 for tool_use in pending_tool_uses.iter() {
1429 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1430 if tool.needs_confirmation(&tool_use.input, cx)
1431 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1432 {
1433 self.tool_use.confirm_tool_use(
1434 tool_use.id.clone(),
1435 tool_use.ui_text.clone(),
1436 tool_use.input.clone(),
1437 messages.clone(),
1438 tool,
1439 );
1440 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1441 } else {
1442 self.run_tool(
1443 tool_use.id.clone(),
1444 tool_use.ui_text.clone(),
1445 tool_use.input.clone(),
1446 &messages,
1447 tool,
1448 cx,
1449 );
1450 }
1451 }
1452 }
1453
1454 pending_tool_uses
1455 }
1456
1457 pub fn run_tool(
1458 &mut self,
1459 tool_use_id: LanguageModelToolUseId,
1460 ui_text: impl Into<SharedString>,
1461 input: serde_json::Value,
1462 messages: &[LanguageModelRequestMessage],
1463 tool: Arc<dyn Tool>,
1464 cx: &mut Context<Thread>,
1465 ) {
1466 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1467 self.tool_use
1468 .run_pending_tool(tool_use_id, ui_text.into(), task);
1469 }
1470
1471 fn spawn_tool_use(
1472 &mut self,
1473 tool_use_id: LanguageModelToolUseId,
1474 messages: &[LanguageModelRequestMessage],
1475 input: serde_json::Value,
1476 tool: Arc<dyn Tool>,
1477 cx: &mut Context<Thread>,
1478 ) -> Task<()> {
1479 let tool_name: Arc<str> = tool.name().into();
1480
1481 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1482 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1483 } else {
1484 tool.run(
1485 input,
1486 messages,
1487 self.project.clone(),
1488 self.action_log.clone(),
1489 cx,
1490 )
1491 };
1492
1493 // Store the card separately if it exists
1494 if let Some(card) = tool_result.card.clone() {
1495 self.tool_use
1496 .insert_tool_result_card(tool_use_id.clone(), card);
1497 }
1498
1499 cx.spawn({
1500 async move |thread: WeakEntity<Thread>, cx| {
1501 let output = tool_result.output.await;
1502
1503 thread
1504 .update(cx, |thread, cx| {
1505 let pending_tool_use = thread.tool_use.insert_tool_output(
1506 tool_use_id.clone(),
1507 tool_name,
1508 output,
1509 cx,
1510 );
1511 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1512 })
1513 .ok();
1514 }
1515 })
1516 }
1517
1518 fn tool_finished(
1519 &mut self,
1520 tool_use_id: LanguageModelToolUseId,
1521 pending_tool_use: Option<PendingToolUse>,
1522 canceled: bool,
1523 cx: &mut Context<Self>,
1524 ) {
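        // Once every pending tool has produced output (or errored), fold the
        // results into a follow-up user message and immediately continue the
        // conversation with the default model, unless this run was canceled.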
1525 if self.all_tools_finished() {
1526 let model_registry = LanguageModelRegistry::read_global(cx);
1527 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1528 self.attach_tool_results(cx);
1529 if !canceled {
1530 self.send_to_model(model, RequestKind::Chat, cx);
1531 }
1532 }
1533 }
1534
1535 cx.emit(ThreadEvent::ToolFinished {
1536 tool_use_id,
1537 pending_tool_use,
1538 });
1539 }
1540
1541 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1542 // Insert a user message to contain the tool results.
1543 self.insert_user_message(
1544 // TODO: Sending up a user message without any content results in the model sending back
1545 // responses that also don't have any content. We currently don't handle this case well,
1546 // so for now we provide some text to keep the model on track.
1547 "Here are the tool results.",
1548 Vec::new(),
1549 None,
1550 cx,
1551 );
1552 }
1553
1554 /// Cancels the last pending completion, if there are any pending.
1555 ///
1556 /// Returns whether a completion was canceled.
1557 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1558 let canceled = if self.pending_completions.pop().is_some() {
1559 true
1560 } else {
1561 let mut canceled = false;
1562 for pending_tool_use in self.tool_use.cancel_pending() {
1563 canceled = true;
1564 self.tool_finished(
1565 pending_tool_use.id.clone(),
1566 Some(pending_tool_use),
1567 true,
1568 cx,
1569 );
1570 }
1571 canceled
1572 };
1573 self.finalize_pending_checkpoint(cx);
1574 canceled
1575 }
1576
1577 pub fn feedback(&self) -> Option<ThreadFeedback> {
1578 self.feedback
1579 }
1580
1581 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1582 self.message_feedback.get(&message_id).copied()
1583 }
1584
1585 pub fn report_message_feedback(
1586 &mut self,
1587 message_id: MessageId,
1588 feedback: ThreadFeedback,
1589 cx: &mut Context<Self>,
1590 ) -> Task<Result<()>> {
1591 if self.message_feedback.get(&message_id) == Some(&feedback) {
1592 return Task::ready(Ok(()));
1593 }
1594
1595 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1596 let serialized_thread = self.serialize(cx);
1597 let thread_id = self.id().clone();
1598 let client = self.project.read(cx).client();
1599
1600 let enabled_tool_names: Vec<String> = self
1601 .tools()
1602 .read(cx)
1603 .enabled_tools(cx)
1604 .iter()
1605 .map(|tool| tool.name().to_string())
1606 .collect();
1607
1608 self.message_feedback.insert(message_id, feedback);
1609
1610 cx.notify();
1611
1612 let message_content = self
1613 .message(message_id)
1614 .map(|msg| msg.to_string())
1615 .unwrap_or_default();
1616
1617 cx.background_spawn(async move {
1618 let final_project_snapshot = final_project_snapshot.await;
1619 let serialized_thread = serialized_thread.await?;
1620 let thread_data =
1621 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1622
1623 let rating = match feedback {
1624 ThreadFeedback::Positive => "positive",
1625 ThreadFeedback::Negative => "negative",
1626 };
1627 telemetry::event!(
1628 "Assistant Thread Rated",
1629 rating,
1630 thread_id,
1631 enabled_tool_names,
1632 message_id = message_id.0,
1633 message_content,
1634 thread_data,
1635 final_project_snapshot
1636 );
1637 client.telemetry().flush_events();
1638
1639 Ok(())
1640 })
1641 }
1642
1643 pub fn report_feedback(
1644 &mut self,
1645 feedback: ThreadFeedback,
1646 cx: &mut Context<Self>,
1647 ) -> Task<Result<()>> {
1648 let last_assistant_message_id = self
1649 .messages
1650 .iter()
1651 .rev()
1652 .find(|msg| msg.role == Role::Assistant)
1653 .map(|msg| msg.id);
1654
1655 if let Some(message_id) = last_assistant_message_id {
1656 self.report_message_feedback(message_id, feedback, cx)
1657 } else {
1658 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1659 let serialized_thread = self.serialize(cx);
1660 let thread_id = self.id().clone();
1661 let client = self.project.read(cx).client();
1662 self.feedback = Some(feedback);
1663 cx.notify();
1664
1665 cx.background_spawn(async move {
1666 let final_project_snapshot = final_project_snapshot.await;
1667 let serialized_thread = serialized_thread.await?;
1668 let thread_data = serde_json::to_value(serialized_thread)
1669 .unwrap_or_else(|_| serde_json::Value::Null);
1670
1671 let rating = match feedback {
1672 ThreadFeedback::Positive => "positive",
1673 ThreadFeedback::Negative => "negative",
1674 };
1675 telemetry::event!(
1676 "Assistant Thread Rated",
1677 rating,
1678 thread_id,
1679 thread_data,
1680 final_project_snapshot
1681 );
1682 client.telemetry().flush_events();
1683
1684 Ok(())
1685 })
1686 }
1687 }
1688
1689 /// Create a snapshot of the current project state including git information and unsaved buffers.
1690 fn project_snapshot(
1691 project: Entity<Project>,
1692 cx: &mut Context<Self>,
1693 ) -> Task<Arc<ProjectSnapshot>> {
1694 let git_store = project.read(cx).git_store().clone();
1695 let worktree_snapshots: Vec<_> = project
1696 .read(cx)
1697 .visible_worktrees(cx)
1698 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1699 .collect();
1700
1701 cx.spawn(async move |_, cx| {
1702 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1703
1704 let mut unsaved_buffers = Vec::new();
1705 cx.update(|app_cx| {
1706 let buffer_store = project.read(app_cx).buffer_store();
1707 for buffer_handle in buffer_store.read(app_cx).buffers() {
1708 let buffer = buffer_handle.read(app_cx);
1709 if buffer.is_dirty() {
1710 if let Some(file) = buffer.file() {
1711 let path = file.path().to_string_lossy().to_string();
1712 unsaved_buffers.push(path);
1713 }
1714 }
1715 }
1716 })
1717 .ok();
1718
1719 Arc::new(ProjectSnapshot {
1720 worktree_snapshots,
1721 unsaved_buffer_paths: unsaved_buffers,
1722 timestamp: Utc::now(),
1723 })
1724 })
1725 }
1726
1727 fn worktree_snapshot(
1728 worktree: Entity<project::Worktree>,
1729 git_store: Entity<GitStore>,
1730 cx: &App,
1731 ) -> Task<WorktreeSnapshot> {
1732 cx.spawn(async move |cx| {
1733 // Get worktree path and snapshot
1734 let worktree_info = cx.update(|app_cx| {
1735 let worktree = worktree.read(app_cx);
1736 let path = worktree.abs_path().to_string_lossy().to_string();
1737 let snapshot = worktree.snapshot();
1738 (path, snapshot)
1739 });
1740
1741 let Ok((worktree_path, _snapshot)) = worktree_info else {
1742 return WorktreeSnapshot {
1743 worktree_path: String::new(),
1744 git_state: None,
1745 };
1746 };
1747
1748 let git_state = git_store
1749 .update(cx, |git_store, cx| {
1750 git_store
1751 .repositories()
1752 .values()
1753 .find(|repo| {
1754 repo.read(cx)
1755 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1756 .is_some()
1757 })
1758 .cloned()
1759 })
1760 .ok()
1761 .flatten()
1762 .map(|repo| {
1763 repo.update(cx, |repo, _| {
1764 let current_branch =
1765 repo.branch.as_ref().map(|branch| branch.name.to_string());
1766 repo.send_job(None, |state, _| async move {
1767 let RepositoryState::Local { backend, .. } = state else {
1768 return GitState {
1769 remote_url: None,
1770 head_sha: None,
1771 current_branch,
1772 diff: None,
1773 };
1774 };
1775
1776 let remote_url = backend.remote_url("origin");
1777 let head_sha = backend.head_sha();
1778 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1779
1780 GitState {
1781 remote_url,
1782 head_sha,
1783 current_branch,
1784 diff,
1785 }
1786 })
1787 })
1788 });
1789
            let git_state = match git_state {
                Some(Ok(job)) => job.await.ok(),
                _ => None,
            };
1797
1798 WorktreeSnapshot {
1799 worktree_path,
1800 git_state,
1801 }
1802 })
1803 }
1804
1805 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1806 let mut markdown = Vec::new();
1807
1808 if let Some(summary) = self.summary() {
1809 writeln!(markdown, "# {summary}\n")?;
1810 };
1811
1812 for message in self.messages() {
1813 writeln!(
1814 markdown,
1815 "## {role}\n",
1816 role = match message.role {
1817 Role::User => "User",
1818 Role::Assistant => "Assistant",
1819 Role::System => "System",
1820 }
1821 )?;
1822
1823 if !message.context.is_empty() {
1824 writeln!(markdown, "{}", message.context)?;
1825 }
1826
1827 for segment in &message.segments {
1828 match segment {
1829 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1830 MessageSegment::Thinking(text) => {
1831 writeln!(markdown, "<think>{}</think>\n", text)?
1832 }
1833 }
1834 }
1835
1836 for tool_use in self.tool_uses_for_message(message.id, cx) {
1837 writeln!(
1838 markdown,
1839 "**Use Tool: {} ({})**",
1840 tool_use.name, tool_use.id
1841 )?;
1842 writeln!(markdown, "```json")?;
1843 writeln!(
1844 markdown,
1845 "{}",
1846 serde_json::to_string_pretty(&tool_use.input)?
1847 )?;
1848 writeln!(markdown, "```")?;
1849 }
1850
1851 for tool_result in self.tool_results_for_message(message.id) {
1852 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1853 if tool_result.is_error {
1854 write!(markdown, " (Error)")?;
1855 }
1856
1857 writeln!(markdown, "**\n")?;
1858 writeln!(markdown, "{}", tool_result.content)?;
1859 }
1860 }
1861
1862 Ok(String::from_utf8_lossy(&markdown).to_string())
1863 }
1864
1865 pub fn keep_edits_in_range(
1866 &mut self,
1867 buffer: Entity<language::Buffer>,
1868 buffer_range: Range<language::Anchor>,
1869 cx: &mut Context<Self>,
1870 ) {
1871 self.action_log.update(cx, |action_log, cx| {
1872 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1873 });
1874 }
1875
1876 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1877 self.action_log
1878 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1879 }
1880
1881 pub fn reject_edits_in_ranges(
1882 &mut self,
1883 buffer: Entity<language::Buffer>,
1884 buffer_ranges: Vec<Range<language::Anchor>>,
1885 cx: &mut Context<Self>,
1886 ) -> Task<Result<()>> {
1887 self.action_log.update(cx, |action_log, cx| {
1888 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1889 })
1890 }
1891
1892 pub fn action_log(&self) -> &Entity<ActionLog> {
1893 &self.action_log
1894 }
1895
1896 pub fn project(&self) -> &Entity<Project> {
1897 &self.project
1898 }
1899
1900 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1901 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1902 return;
1903 }
1904
1905 let now = Instant::now();
1906 if let Some(last) = self.last_auto_capture_at {
1907 if now.duration_since(last).as_secs() < 10 {
1908 return;
1909 }
1910 }
1911
1912 self.last_auto_capture_at = Some(now);
1913
1914 let thread_id = self.id().clone();
1915 let github_login = self
1916 .project
1917 .read(cx)
1918 .user_store()
1919 .read(cx)
1920 .current_user()
1921 .map(|user| user.github_login.clone());
1922 let client = self.project.read(cx).client().clone();
1923 let serialize_task = self.serialize(cx);
1924
1925 cx.background_executor()
1926 .spawn(async move {
1927 if let Ok(serialized_thread) = serialize_task.await {
1928 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1929 telemetry::event!(
1930 "Agent Thread Auto-Captured",
1931 thread_id = thread_id.to_string(),
1932 thread_data = thread_data,
1933 auto_capture_reason = "tracked_user",
1934 github_login = github_login
1935 );
1936
1937 client.telemetry().flush_events();
1938 }
1939 }
1940 })
1941 .detach();
1942 }
1943
1944 pub fn cumulative_token_usage(&self) -> TokenUsage {
1945 self.cumulative_token_usage
1946 }
1947
1948 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
1949 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
1950 return TotalTokenUsage::default();
1951 };
1952
1953 let max = model.model.max_token_count();
1954
1955 let index = self
1956 .messages
1957 .iter()
1958 .position(|msg| msg.id == message_id)
1959 .unwrap_or(0);
1960
1961 if index == 0 {
1962 return TotalTokenUsage { total: 0, max };
1963 }
1964
1965 let token_usage = &self
1966 .request_token_usage
1967 .get(index - 1)
1968 .cloned()
1969 .unwrap_or_default();
1970
1971 TotalTokenUsage {
1972 total: token_usage.total_tokens() as usize,
1973 max,
1974 }
1975 }
1976
1977 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1978 let model_registry = LanguageModelRegistry::read_global(cx);
1979 let Some(model) = model_registry.default_model() else {
1980 return TotalTokenUsage::default();
1981 };
1982
1983 let max = model.model.max_token_count();
1984
1985 if let Some(exceeded_error) = &self.exceeded_window_error {
1986 if model.model.id() == exceeded_error.model_id {
1987 return TotalTokenUsage {
1988 total: exceeded_error.token_count,
1989 max,
1990 };
1991 }
1992 }
1993
1994 let total = self
1995 .token_usage_at_last_message()
1996 .unwrap_or_default()
1997 .total_tokens() as usize;
1998
1999 TotalTokenUsage { total, max }
2000 }
2001
2002 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2003 self.request_token_usage
2004 .get(self.messages.len().saturating_sub(1))
2005 .or_else(|| self.request_token_usage.last())
2006 .cloned()
2007 }
2008
2009 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
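        // Keep one usage entry per message: pad the history with the last known
        // usage for any messages that arrived without their own update, then
        // overwrite the final entry with the fresh numbers.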
2010 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2011 self.request_token_usage
2012 .resize(self.messages.len(), placeholder);
2013
2014 if let Some(last) = self.request_token_usage.last_mut() {
2015 *last = token_usage;
2016 }
2017 }
2018
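    /// Rejects a pending tool use on the user's behalf by recording a "permission
    /// denied" error as the tool's output and marking the tool call as finished.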
    pub fn deny_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        cx: &mut Context<Self>,
    ) {
        let err = Err(anyhow::anyhow!(
            "Permission to run tool action denied by user"
        ));

        self.tool_use
            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
        self.tool_finished(tool_use_id.clone(), None, true, cx);
    }
}

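/// Errors surfaced to the user while running a thread.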
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    #[error("Payment required")]
    PaymentRequired,
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}

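/// Events emitted by a [`Thread`] as a completion streams and tools run.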
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    ShowError(ThreadError),
    UsageUpdated(RequestUsage),
    StreamedCompletion,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}

impl EventEmitter<ThreadEvent> for Thread {}

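/// A completion request that is currently streaming. The stored task keeps the
/// request alive; dropping it cancels the completion.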
struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use a platform-appropriate path separator in the expected context
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check the entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain all 3 user messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert a new line of code so the buffer no longer matches what was last read
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

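    // Helper to build the workspace, thread store, thread, and context store shared by the tests above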
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await;

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

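    // Helper to open a buffer from the project and attach it to the context store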
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}