1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
42
43#[derive(Debug, Clone, Copy)]
44pub enum RequestKind {
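    /// Used for regular conversation turns, including tool use.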
45 Chat,
46 /// Used when summarizing a thread.
47 Summarize,
48}
49
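/// A globally unique identifier for a [`Thread`], stored as a string (a freshly generated UUID by default).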
50#[derive(
51 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
52)]
53pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
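/// An identifier for a [`Message`] within a thread, assigned in increasing order as messages are inserted.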
73#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
74pub struct MessageId(pub(crate) usize);
75
76impl MessageId {
77 fn post_inc(&mut self) -> Self {
78 Self(post_inc(&mut self.0))
79 }
80}
81
82/// A message in a [`Thread`].
83#[derive(Debug, Clone)]
84pub struct Message {
85 pub id: MessageId,
86 pub role: Role,
87 pub segments: Vec<MessageSegment>,
88 pub context: String,
89}
90
91impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    /// The model sometimes runs tools without producing any text, or produces only a marker ([`USING_TOOL_MARKER`]).
94 pub fn should_display_content(&self) -> bool {
95 self.segments.iter().all(|segment| segment.should_display())
96 }
97
98 pub fn push_thinking(&mut self, text: &str) {
99 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
100 segment.push_str(text);
101 } else {
102 self.segments
103 .push(MessageSegment::Thinking(text.to_string()));
104 }
105 }
106
107 pub fn push_text(&mut self, text: &str) {
108 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
109 segment.push_str(text);
110 } else {
111 self.segments.push(MessageSegment::Text(text.to_string()));
112 }
113 }
114
115 pub fn to_string(&self) -> String {
116 let mut result = String::new();
117
118 if !self.context.is_empty() {
119 result.push_str(&self.context);
120 }
121
122 for segment in &self.segments {
123 match segment {
124 MessageSegment::Text(text) => result.push_str(text),
125 MessageSegment::Thinking(text) => {
126 result.push_str("<think>");
127 result.push_str(text);
128 result.push_str("</think>");
129 }
130 }
131 }
132
133 result
134 }
135}
136
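/// A piece of a [`Message`]: either plain text or a model "thinking" block.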
137#[derive(Debug, Clone, PartialEq, Eq)]
138pub enum MessageSegment {
139 Text(String),
140 Thinking(String),
141}
142
143impl MessageSegment {
144 pub fn text_mut(&mut self) -> &mut String {
145 match self {
146 Self::Text(text) => text,
147 Self::Thinking(text) => text,
148 }
149 }
150
    pub fn should_display(&self) -> bool {
        // We add USING_TOOL_MARKER when making a request that includes tool uses
        // without non-whitespace text around them, and this can cause the model
        // to mimic the pattern, so we consider those segments not displayable.
        match self {
            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::Thinking(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
        }
    }
160}
161
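/// A snapshot of the project's worktrees, git state, and unsaved buffers, captured for storage and telemetry.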
162#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct ProjectSnapshot {
164 pub worktree_snapshots: Vec<WorktreeSnapshot>,
165 pub unsaved_buffer_paths: Vec<String>,
166 pub timestamp: DateTime<Utc>,
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct WorktreeSnapshot {
171 pub worktree_path: String,
172 pub git_state: Option<GitState>,
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct GitState {
177 pub remote_url: Option<String>,
178 pub head_sha: Option<String>,
179 pub current_branch: Option<String>,
180 pub diff: Option<String>,
181}
182
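/// Associates a git checkpoint with the message it was captured for, so the project can later be restored to that state.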
183#[derive(Clone)]
184pub struct ThreadCheckpoint {
185 message_id: MessageId,
186 git_checkpoint: GitStoreCheckpoint,
187}
188
189#[derive(Copy, Clone, Debug, PartialEq, Eq)]
190pub enum ThreadFeedback {
191 Positive,
192 Negative,
193}
194
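/// The state of the most recent checkpoint restore: still pending, or failed with an error.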
195pub enum LastRestoreCheckpoint {
196 Pending {
197 message_id: MessageId,
198 },
199 Error {
200 message_id: MessageId,
201 error: String,
202 },
203}
204
205impl LastRestoreCheckpoint {
206 pub fn message_id(&self) -> MessageId {
207 match self {
208 LastRestoreCheckpoint::Pending { message_id } => *message_id,
209 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
210 }
211 }
212}
213
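/// Tracks whether a detailed, Markdown-formatted summary of the thread has been generated.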
214#[derive(Clone, Debug, Default, Serialize, Deserialize)]
215pub enum DetailedSummaryState {
216 #[default]
217 NotGenerated,
218 Generating {
219 message_id: MessageId,
220 },
221 Generated {
222 text: SharedString,
223 message_id: MessageId,
224 },
225}
226
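/// The thread's token usage relative to the active model's maximum context window.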
227#[derive(Default)]
228pub struct TotalTokenUsage {
229 pub total: usize,
230 pub max: usize,
231}
232
233impl TotalTokenUsage {
234 pub fn ratio(&self) -> TokenUsageRatio {
235 #[cfg(debug_assertions)]
236 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
237 .unwrap_or("0.8".to_string())
238 .parse()
239 .unwrap();
240 #[cfg(not(debug_assertions))]
241 let warning_threshold: f32 = 0.8;
242
243 if self.total >= self.max {
244 TokenUsageRatio::Exceeded
245 } else if self.total as f32 / self.max as f32 >= warning_threshold {
246 TokenUsageRatio::Warning
247 } else {
248 TokenUsageRatio::Normal
249 }
250 }
251
252 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
253 TotalTokenUsage {
254 total: self.total + tokens,
255 max: self.max,
256 }
257 }
258}
259
260#[derive(Debug, Default, PartialEq, Eq)]
261pub enum TokenUsageRatio {
262 #[default]
263 Normal,
264 Warning,
265 Exceeded,
266}
267
268/// A thread of conversation with the LLM.
269pub struct Thread {
270 id: ThreadId,
271 updated_at: DateTime<Utc>,
272 summary: Option<SharedString>,
273 pending_summary: Task<Option<()>>,
274 detailed_summary_state: DetailedSummaryState,
275 messages: Vec<Message>,
276 next_message_id: MessageId,
277 context: BTreeMap<ContextId, AssistantContext>,
278 context_by_message: HashMap<MessageId, Vec<ContextId>>,
279 project_context: SharedProjectContext,
280 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
281 completion_count: usize,
282 pending_completions: Vec<PendingCompletion>,
283 project: Entity<Project>,
284 prompt_builder: Arc<PromptBuilder>,
285 tools: Entity<ToolWorkingSet>,
286 tool_use: ToolUseState,
287 action_log: Entity<ActionLog>,
288 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
289 pending_checkpoint: Option<ThreadCheckpoint>,
290 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
291 request_token_usage: Vec<TokenUsage>,
292 cumulative_token_usage: TokenUsage,
293 exceeded_window_error: Option<ExceededWindowError>,
294 feedback: Option<ThreadFeedback>,
295 message_feedback: HashMap<MessageId, ThreadFeedback>,
296 last_auto_capture_at: Option<Instant>,
297}
298
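/// Records a completion that exceeded the model's context window.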
299#[derive(Debug, Clone, Serialize, Deserialize)]
300pub struct ExceededWindowError {
    /// The model that was in use when the last message exceeded the context window.
    model_id: LanguageModelId,
    /// The total token count, including the last message.
    token_count: usize,
305}
306
307impl Thread {
308 pub fn new(
309 project: Entity<Project>,
310 tools: Entity<ToolWorkingSet>,
311 prompt_builder: Arc<PromptBuilder>,
312 system_prompt: SharedProjectContext,
313 cx: &mut Context<Self>,
314 ) -> Self {
315 Self {
316 id: ThreadId::new(),
317 updated_at: Utc::now(),
318 summary: None,
319 pending_summary: Task::ready(None),
320 detailed_summary_state: DetailedSummaryState::NotGenerated,
321 messages: Vec::new(),
322 next_message_id: MessageId(0),
323 context: BTreeMap::default(),
324 context_by_message: HashMap::default(),
325 project_context: system_prompt,
326 checkpoints_by_message: HashMap::default(),
327 completion_count: 0,
328 pending_completions: Vec::new(),
329 project: project.clone(),
330 prompt_builder,
331 tools: tools.clone(),
332 last_restore_checkpoint: None,
333 pending_checkpoint: None,
334 tool_use: ToolUseState::new(tools.clone()),
335 action_log: cx.new(|_| ActionLog::new(project.clone())),
336 initial_project_snapshot: {
337 let project_snapshot = Self::project_snapshot(project, cx);
338 cx.foreground_executor()
339 .spawn(async move { Some(project_snapshot.await) })
340 .shared()
341 },
342 request_token_usage: Vec::new(),
343 cumulative_token_usage: TokenUsage::default(),
344 exceeded_window_error: None,
345 feedback: None,
346 message_feedback: HashMap::default(),
347 last_auto_capture_at: None,
348 }
349 }
350
351 pub fn deserialize(
352 id: ThreadId,
353 serialized: SerializedThread,
354 project: Entity<Project>,
355 tools: Entity<ToolWorkingSet>,
356 prompt_builder: Arc<PromptBuilder>,
357 project_context: SharedProjectContext,
358 cx: &mut Context<Self>,
359 ) -> Self {
360 let next_message_id = MessageId(
361 serialized
362 .messages
363 .last()
364 .map(|message| message.id.0 + 1)
365 .unwrap_or(0),
366 );
367 let tool_use =
368 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
369
370 Self {
371 id,
372 updated_at: serialized.updated_at,
373 summary: Some(serialized.summary),
374 pending_summary: Task::ready(None),
375 detailed_summary_state: serialized.detailed_summary_state,
376 messages: serialized
377 .messages
378 .into_iter()
379 .map(|message| Message {
380 id: message.id,
381 role: message.role,
382 segments: message
383 .segments
384 .into_iter()
385 .map(|segment| match segment {
386 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
387 SerializedMessageSegment::Thinking { text } => {
388 MessageSegment::Thinking(text)
389 }
390 })
391 .collect(),
392 context: message.context,
393 })
394 .collect(),
395 next_message_id,
396 context: BTreeMap::default(),
397 context_by_message: HashMap::default(),
398 project_context,
399 checkpoints_by_message: HashMap::default(),
400 completion_count: 0,
401 pending_completions: Vec::new(),
402 last_restore_checkpoint: None,
403 pending_checkpoint: None,
404 project: project.clone(),
405 prompt_builder,
406 tools,
407 tool_use,
408 action_log: cx.new(|_| ActionLog::new(project)),
409 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
410 request_token_usage: serialized.request_token_usage,
411 cumulative_token_usage: serialized.cumulative_token_usage,
412 exceeded_window_error: None,
413 feedback: None,
414 message_feedback: HashMap::default(),
415 last_auto_capture_at: None,
416 }
417 }
418
419 pub fn id(&self) -> &ThreadId {
420 &self.id
421 }
422
423 pub fn is_empty(&self) -> bool {
424 self.messages.is_empty()
425 }
426
427 pub fn updated_at(&self) -> DateTime<Utc> {
428 self.updated_at
429 }
430
431 pub fn touch_updated_at(&mut self) {
432 self.updated_at = Utc::now();
433 }
434
435 pub fn summary(&self) -> Option<SharedString> {
436 self.summary.clone()
437 }
438
439 pub fn project_context(&self) -> SharedProjectContext {
440 self.project_context.clone()
441 }
442
443 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
444
445 pub fn summary_or_default(&self) -> SharedString {
446 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
447 }
448
449 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
450 let Some(current_summary) = &self.summary else {
451 // Don't allow setting summary until generated
452 return;
453 };
454
455 let mut new_summary = new_summary.into();
456
457 if new_summary.is_empty() {
458 new_summary = Self::DEFAULT_SUMMARY;
459 }
460
461 if current_summary != &new_summary {
462 self.summary = Some(new_summary);
463 cx.emit(ThreadEvent::SummaryChanged);
464 }
465 }
466
467 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
468 self.latest_detailed_summary()
469 .unwrap_or_else(|| self.text().into())
470 }
471
472 fn latest_detailed_summary(&self) -> Option<SharedString> {
473 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
474 Some(text.clone())
475 } else {
476 None
477 }
478 }
479
480 pub fn message(&self, id: MessageId) -> Option<&Message> {
481 self.messages.iter().find(|message| message.id == id)
482 }
483
484 pub fn messages(&self) -> impl Iterator<Item = &Message> {
485 self.messages.iter()
486 }
487
488 pub fn is_generating(&self) -> bool {
489 !self.pending_completions.is_empty() || !self.all_tools_finished()
490 }
491
492 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
493 &self.tools
494 }
495
496 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
497 self.tool_use
498 .pending_tool_uses()
499 .into_iter()
500 .find(|tool_use| &tool_use.id == id)
501 }
502
503 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
504 self.tool_use
505 .pending_tool_uses()
506 .into_iter()
507 .filter(|tool_use| tool_use.status.needs_confirmation())
508 }
509
510 pub fn has_pending_tool_uses(&self) -> bool {
511 !self.tool_use.pending_tool_uses().is_empty()
512 }
513
514 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
515 self.checkpoints_by_message.get(&id).cloned()
516 }
517
518 pub fn restore_checkpoint(
519 &mut self,
520 checkpoint: ThreadCheckpoint,
521 cx: &mut Context<Self>,
522 ) -> Task<Result<()>> {
523 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
524 message_id: checkpoint.message_id,
525 });
526 cx.emit(ThreadEvent::CheckpointChanged);
527 cx.notify();
528
529 let git_store = self.project().read(cx).git_store().clone();
530 let restore = git_store.update(cx, |git_store, cx| {
531 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
532 });
533
534 cx.spawn(async move |this, cx| {
535 let result = restore.await;
536 this.update(cx, |this, cx| {
537 if let Err(err) = result.as_ref() {
538 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
539 message_id: checkpoint.message_id,
540 error: err.to_string(),
541 });
542 } else {
543 this.truncate(checkpoint.message_id, cx);
544 this.last_restore_checkpoint = None;
545 }
546 this.pending_checkpoint = None;
547 cx.emit(ThreadEvent::CheckpointChanged);
548 cx.notify();
549 })?;
550 result
551 })
552 }
553
554 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
555 let pending_checkpoint = if self.is_generating() {
556 return;
557 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
558 checkpoint
559 } else {
560 return;
561 };
562
563 let git_store = self.project.read(cx).git_store().clone();
564 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
565 cx.spawn(async move |this, cx| match final_checkpoint.await {
566 Ok(final_checkpoint) => {
567 let equal = git_store
568 .update(cx, |store, cx| {
569 store.compare_checkpoints(
570 pending_checkpoint.git_checkpoint.clone(),
571 final_checkpoint.clone(),
572 cx,
573 )
574 })?
575 .await
576 .unwrap_or(false);
577
578 if equal {
579 git_store
580 .update(cx, |store, cx| {
581 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
582 })?
583 .detach();
584 } else {
585 this.update(cx, |this, cx| {
586 this.insert_checkpoint(pending_checkpoint, cx)
587 })?;
588 }
589
590 git_store
591 .update(cx, |store, cx| {
592 store.delete_checkpoint(final_checkpoint, cx)
593 })?
594 .detach();
595
596 Ok(())
597 }
598 Err(_) => this.update(cx, |this, cx| {
599 this.insert_checkpoint(pending_checkpoint, cx)
600 }),
601 })
602 .detach();
603 }
604
605 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
606 self.checkpoints_by_message
607 .insert(checkpoint.message_id, checkpoint);
608 cx.emit(ThreadEvent::CheckpointChanged);
609 cx.notify();
610 }
611
612 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
613 self.last_restore_checkpoint.as_ref()
614 }
615
616 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
617 let Some(message_ix) = self
618 .messages
619 .iter()
620 .rposition(|message| message.id == message_id)
621 else {
622 return;
623 };
624 for deleted_message in self.messages.drain(message_ix..) {
625 self.context_by_message.remove(&deleted_message.id);
626 self.checkpoints_by_message.remove(&deleted_message.id);
627 }
628 cx.notify();
629 }
630
631 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
632 self.context_by_message
633 .get(&id)
634 .into_iter()
635 .flat_map(|context| {
636 context
637 .iter()
638 .filter_map(|context_id| self.context.get(&context_id))
639 })
640 }
641
642 /// Returns whether all of the tool uses have finished running.
643 pub fn all_tools_finished(&self) -> bool {
644 // If the only pending tool uses left are the ones with errors, then
645 // that means that we've finished running all of the pending tools.
646 self.tool_use
647 .pending_tool_uses()
648 .iter()
649 .all(|tool_use| tool_use.status.is_error())
650 }
651
652 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
653 self.tool_use.tool_uses_for_message(id, cx)
654 }
655
656 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
657 self.tool_use.tool_results_for_message(id)
658 }
659
660 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
661 self.tool_use.tool_result(id)
662 }
663
664 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
665 Some(&self.tool_use.tool_result(id)?.content)
666 }
667
668 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
669 self.tool_use.tool_result_card(id).cloned()
670 }
671
672 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
673 self.tool_use.message_has_tool_results(message_id)
674 }
675
    /// Filters out contexts that have already been included in previous messages.
677 pub fn filter_new_context<'a>(
678 &self,
679 context: impl Iterator<Item = &'a AssistantContext>,
680 ) -> impl Iterator<Item = &'a AssistantContext> {
681 context.filter(|ctx| self.is_context_new(ctx))
682 }
683
684 fn is_context_new(&self, context: &AssistantContext) -> bool {
685 !self.context.contains_key(&context.id())
686 }
687
688 pub fn insert_user_message(
689 &mut self,
690 text: impl Into<String>,
691 context: Vec<AssistantContext>,
692 git_checkpoint: Option<GitStoreCheckpoint>,
693 cx: &mut Context<Self>,
694 ) -> MessageId {
695 let text = text.into();
696
697 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
698
699 let new_context: Vec<_> = context
700 .into_iter()
701 .filter(|ctx| self.is_context_new(ctx))
702 .collect();
703
704 if !new_context.is_empty() {
705 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
706 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
707 message.context = context_string;
708 }
709 }
710
711 self.action_log.update(cx, |log, cx| {
712 // Track all buffers added as context
713 for ctx in &new_context {
714 match ctx {
715 AssistantContext::File(file_ctx) => {
716 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
717 }
718 AssistantContext::Directory(dir_ctx) => {
719 for context_buffer in &dir_ctx.context_buffers {
720 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
721 }
722 }
723 AssistantContext::Symbol(symbol_ctx) => {
724 log.buffer_added_as_context(
725 symbol_ctx.context_symbol.buffer.clone(),
726 cx,
727 );
728 }
729 AssistantContext::Excerpt(excerpt_context) => {
730 log.buffer_added_as_context(
731 excerpt_context.context_buffer.buffer.clone(),
732 cx,
733 );
734 }
735 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
736 }
737 }
738 });
739 }
740
741 let context_ids = new_context
742 .iter()
743 .map(|context| context.id())
744 .collect::<Vec<_>>();
745 self.context.extend(
746 new_context
747 .into_iter()
748 .map(|context| (context.id(), context)),
749 );
750 self.context_by_message.insert(message_id, context_ids);
751
752 if let Some(git_checkpoint) = git_checkpoint {
753 self.pending_checkpoint = Some(ThreadCheckpoint {
754 message_id,
755 git_checkpoint,
756 });
757 }
758
759 self.auto_capture_telemetry(cx);
760
761 message_id
762 }
763
764 pub fn insert_message(
765 &mut self,
766 role: Role,
767 segments: Vec<MessageSegment>,
768 cx: &mut Context<Self>,
769 ) -> MessageId {
770 let id = self.next_message_id.post_inc();
771 self.messages.push(Message {
772 id,
773 role,
774 segments,
775 context: String::new(),
776 });
777 self.touch_updated_at();
778 cx.emit(ThreadEvent::MessageAdded(id));
779 id
780 }
781
782 pub fn edit_message(
783 &mut self,
784 id: MessageId,
785 new_role: Role,
786 new_segments: Vec<MessageSegment>,
787 cx: &mut Context<Self>,
788 ) -> bool {
789 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
790 return false;
791 };
792 message.role = new_role;
793 message.segments = new_segments;
794 self.touch_updated_at();
795 cx.emit(ThreadEvent::MessageEdited(id));
796 true
797 }
798
799 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
800 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
801 return false;
802 };
803 self.messages.remove(index);
804 self.context_by_message.remove(&id);
805 self.touch_updated_at();
806 cx.emit(ThreadEvent::MessageDeleted(id));
807 true
808 }
809
810 /// Returns the representation of this [`Thread`] in a textual form.
811 ///
812 /// This is the representation we use when attaching a thread as context to another thread.
813 pub fn text(&self) -> String {
814 let mut text = String::new();
815
816 for message in &self.messages {
817 text.push_str(match message.role {
818 language_model::Role::User => "User:",
819 language_model::Role::Assistant => "Assistant:",
820 language_model::Role::System => "System:",
821 });
822 text.push('\n');
823
824 for segment in &message.segments {
825 match segment {
826 MessageSegment::Text(content) => text.push_str(content),
827 MessageSegment::Thinking(content) => {
828 text.push_str(&format!("<think>{}</think>", content))
829 }
830 }
831 }
832 text.push('\n');
833 }
834
835 text
836 }
837
838 /// Serializes this thread into a format for storage or telemetry.
839 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
840 let initial_project_snapshot = self.initial_project_snapshot.clone();
841 cx.spawn(async move |this, cx| {
842 let initial_project_snapshot = initial_project_snapshot.await;
843 this.read_with(cx, |this, cx| SerializedThread {
844 version: SerializedThread::VERSION.to_string(),
845 summary: this.summary_or_default(),
846 updated_at: this.updated_at(),
847 messages: this
848 .messages()
849 .map(|message| SerializedMessage {
850 id: message.id,
851 role: message.role,
852 segments: message
853 .segments
854 .iter()
855 .map(|segment| match segment {
856 MessageSegment::Text(text) => {
857 SerializedMessageSegment::Text { text: text.clone() }
858 }
859 MessageSegment::Thinking(text) => {
860 SerializedMessageSegment::Thinking { text: text.clone() }
861 }
862 })
863 .collect(),
864 tool_uses: this
865 .tool_uses_for_message(message.id, cx)
866 .into_iter()
867 .map(|tool_use| SerializedToolUse {
868 id: tool_use.id,
869 name: tool_use.name,
870 input: tool_use.input,
871 })
872 .collect(),
873 tool_results: this
874 .tool_results_for_message(message.id)
875 .into_iter()
876 .map(|tool_result| SerializedToolResult {
877 tool_use_id: tool_result.tool_use_id.clone(),
878 is_error: tool_result.is_error,
879 content: tool_result.content.clone(),
880 })
881 .collect(),
882 context: message.context.clone(),
883 })
884 .collect(),
885 initial_project_snapshot,
886 cumulative_token_usage: this.cumulative_token_usage,
887 request_token_usage: this.request_token_usage.clone(),
888 detailed_summary_state: this.detailed_summary_state.clone(),
889 exceeded_window_error: this.exceeded_window_error.clone(),
890 })
891 })
892 }
893
894 pub fn send_to_model(
895 &mut self,
896 model: Arc<dyn LanguageModel>,
897 request_kind: RequestKind,
898 cx: &mut Context<Self>,
899 ) {
900 let mut request = self.to_completion_request(request_kind, cx);
901 if model.supports_tools() {
902 request.tools = {
903 let mut tools = Vec::new();
904 tools.extend(
905 self.tools()
906 .read(cx)
907 .enabled_tools(cx)
908 .into_iter()
909 .filter_map(|tool| {
                            // Skip tools whose input schema can't be generated for this model's tool format
911 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
912 Some(LanguageModelRequestTool {
913 name: tool.name(),
914 description: tool.description(),
915 input_schema,
916 })
917 }),
918 );
919
920 tools
921 };
922 }
923
924 self.stream_completion(request, model, cx);
925 }
926
927 pub fn used_tools_since_last_user_message(&self) -> bool {
928 for message in self.messages.iter().rev() {
929 if self.tool_use.message_has_tool_results(message.id) {
930 return true;
931 } else if message.role == Role::User {
932 return false;
933 }
934 }
935
936 false
937 }
938
939 pub fn to_completion_request(
940 &self,
941 request_kind: RequestKind,
942 cx: &mut Context<Self>,
943 ) -> LanguageModelRequest {
944 let mut request = LanguageModelRequest {
945 messages: vec![],
946 tools: Vec::new(),
947 stop: Vec::new(),
948 temperature: None,
949 };
950
951 if let Some(project_context) = self.project_context.borrow().as_ref() {
952 match self
953 .prompt_builder
954 .generate_assistant_system_prompt(project_context)
955 {
956 Err(err) => {
957 let message = format!("{err:?}").into();
958 log::error!("{message}");
959 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
960 header: "Error generating system prompt".into(),
961 message,
962 }));
963 }
964 Ok(system_prompt) => {
965 request.messages.push(LanguageModelRequestMessage {
966 role: Role::System,
967 content: vec![MessageContent::Text(system_prompt)],
968 cache: true,
969 });
970 }
971 }
972 } else {
973 let message = "Context for system prompt unexpectedly not ready.".into();
974 log::error!("{message}");
975 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
976 header: "Error generating system prompt".into(),
977 message,
978 }));
979 }
980
981 for message in &self.messages {
982 let mut request_message = LanguageModelRequestMessage {
983 role: message.role,
984 content: Vec::new(),
985 cache: false,
986 };
987
988 match request_kind {
989 RequestKind::Chat => {
990 self.tool_use
991 .attach_tool_results(message.id, &mut request_message);
992 }
993 RequestKind::Summarize => {
994 // We don't care about tool use during summarization.
995 if self.tool_use.message_has_tool_results(message.id) {
996 continue;
997 }
998 }
999 }
1000
1001 if !message.segments.is_empty() {
1002 request_message
1003 .content
1004 .push(MessageContent::Text(message.to_string()));
1005 }
1006
1007 match request_kind {
1008 RequestKind::Chat => {
1009 self.tool_use
1010 .attach_tool_uses(message.id, &mut request_message);
1011 }
1012 RequestKind::Summarize => {
1013 // We don't care about tool use during summarization.
1014 }
1015 };
1016
1017 request.messages.push(request_message);
1018 }
1019
1020 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1021 if let Some(last) = request.messages.last_mut() {
1022 last.cache = true;
1023 }
1024
1025 self.attached_tracked_files_state(&mut request.messages, cx);
1026
1027 request
1028 }
1029
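    /// Appends a user message listing any files that changed since the model last read them,
    /// and a reminder to check project diagnostics if files were edited since the last check.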
1030 fn attached_tracked_files_state(
1031 &self,
1032 messages: &mut Vec<LanguageModelRequestMessage>,
1033 cx: &App,
1034 ) {
1035 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1036
1037 let mut stale_message = String::new();
1038
1039 let action_log = self.action_log.read(cx);
1040
1041 for stale_file in action_log.stale_buffers(cx) {
1042 let Some(file) = stale_file.read(cx).file() else {
1043 continue;
1044 };
1045
            if stale_message.is_empty() {
                writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
            }
1049
1050 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1051 }
1052
1053 let mut content = Vec::with_capacity(2);
1054
1055 if !stale_message.is_empty() {
1056 content.push(stale_message.into());
1057 }
1058
1059 if action_log.has_edited_files_since_project_diagnostics_check() {
1060 content.push(
1061 "\n\nWhen you're done making changes, make sure to check project diagnostics \
1062 and fix all errors AND warnings you introduced! \
1063 DO NOT mention you're going to do this until you're done."
1064 .into(),
1065 );
1066 }
1067
1068 if !content.is_empty() {
1069 let context_message = LanguageModelRequestMessage {
1070 role: Role::User,
1071 content,
1072 cache: false,
1073 };
1074
1075 messages.push(context_message);
1076 }
1077 }
1078
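    /// Spawns a task that streams a completion for the given request, applying streamed events
    /// to the thread as they arrive and handling tool-use stops, errors, and telemetry when the
    /// stream ends.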
1079 pub fn stream_completion(
1080 &mut self,
1081 request: LanguageModelRequest,
1082 model: Arc<dyn LanguageModel>,
1083 cx: &mut Context<Self>,
1084 ) {
1085 let pending_completion_id = post_inc(&mut self.completion_count);
1086 let task = cx.spawn(async move |thread, cx| {
1087 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1088 let initial_token_usage =
1089 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1090 let stream_completion = async {
1091 let (mut events, usage) = stream_completion_future.await?;
1092 let mut stop_reason = StopReason::EndTurn;
1093 let mut current_token_usage = TokenUsage::default();
1094
1095 if let Some(usage) = usage {
1096 thread
1097 .update(cx, |_thread, cx| {
1098 cx.emit(ThreadEvent::UsageUpdated(usage));
1099 })
1100 .ok();
1101 }
1102
1103 while let Some(event) = events.next().await {
1104 let event = event?;
1105
1106 thread.update(cx, |thread, cx| {
1107 match event {
1108 LanguageModelCompletionEvent::StartMessage { .. } => {
1109 thread.insert_message(
1110 Role::Assistant,
1111 vec![MessageSegment::Text(String::new())],
1112 cx,
1113 );
1114 }
1115 LanguageModelCompletionEvent::Stop(reason) => {
1116 stop_reason = reason;
1117 }
1118 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1119 thread.update_token_usage_at_last_message(token_usage);
1120 thread.cumulative_token_usage = thread.cumulative_token_usage
1121 + token_usage
1122 - current_token_usage;
1123 current_token_usage = token_usage;
1124 }
1125 LanguageModelCompletionEvent::Text(chunk) => {
1126 if let Some(last_message) = thread.messages.last_mut() {
1127 if last_message.role == Role::Assistant {
1128 last_message.push_text(&chunk);
1129 cx.emit(ThreadEvent::StreamedAssistantText(
1130 last_message.id,
1131 chunk,
1132 ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1139 thread.insert_message(
1140 Role::Assistant,
1141 vec![MessageSegment::Text(chunk.to_string())],
1142 cx,
1143 );
1144 };
1145 }
1146 }
1147 LanguageModelCompletionEvent::Thinking(chunk) => {
1148 if let Some(last_message) = thread.messages.last_mut() {
1149 if last_message.role == Role::Assistant {
1150 last_message.push_thinking(&chunk);
1151 cx.emit(ThreadEvent::StreamedAssistantThinking(
1152 last_message.id,
1153 chunk,
1154 ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1161 thread.insert_message(
1162 Role::Assistant,
1163 vec![MessageSegment::Thinking(chunk.to_string())],
1164 cx,
1165 );
1166 };
1167 }
1168 }
1169 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1170 let last_assistant_message_id = thread
1171 .messages
1172 .iter_mut()
1173 .rfind(|message| message.role == Role::Assistant)
1174 .map(|message| message.id)
1175 .unwrap_or_else(|| {
1176 thread.insert_message(Role::Assistant, vec![], cx)
1177 });
1178
1179 thread.tool_use.request_tool_use(
1180 last_assistant_message_id,
1181 tool_use,
1182 cx,
1183 );
1184 }
1185 }
1186
1187 thread.touch_updated_at();
1188 cx.emit(ThreadEvent::StreamedCompletion);
1189 cx.notify();
1190
1191 thread.auto_capture_telemetry(cx);
1192 })?;
1193
1194 smol::future::yield_now().await;
1195 }
1196
1197 thread.update(cx, |thread, cx| {
1198 thread
1199 .pending_completions
1200 .retain(|completion| completion.id != pending_completion_id);
1201
1202 if thread.summary.is_none() && thread.messages.len() >= 2 {
1203 thread.summarize(cx);
1204 }
1205 })?;
1206
1207 anyhow::Ok(stop_reason)
1208 };
1209
1210 let result = stream_completion.await;
1211
1212 thread
1213 .update(cx, |thread, cx| {
1214 thread.finalize_pending_checkpoint(cx);
1215 match result.as_ref() {
1216 Ok(stop_reason) => match stop_reason {
1217 StopReason::ToolUse => {
1218 let tool_uses = thread.use_pending_tools(cx);
1219 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1220 }
1221 StopReason::EndTurn => {}
1222 StopReason::MaxTokens => {}
1223 },
1224 Err(error) => {
1225 if error.is::<PaymentRequiredError>() {
1226 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1227 } else if error.is::<MaxMonthlySpendReachedError>() {
1228 cx.emit(ThreadEvent::ShowError(
1229 ThreadError::MaxMonthlySpendReached,
1230 ));
1231 } else if let Some(error) =
1232 error.downcast_ref::<ModelRequestLimitReachedError>()
1233 {
1234 cx.emit(ThreadEvent::ShowError(
1235 ThreadError::ModelRequestLimitReached { plan: error.plan },
1236 ));
1237 } else if let Some(known_error) =
1238 error.downcast_ref::<LanguageModelKnownError>()
1239 {
1240 match known_error {
1241 LanguageModelKnownError::ContextWindowLimitExceeded {
1242 tokens,
1243 } => {
1244 thread.exceeded_window_error = Some(ExceededWindowError {
1245 model_id: model.id(),
1246 token_count: *tokens,
1247 });
1248 cx.notify();
1249 }
1250 }
1251 } else {
1252 let error_message = error
1253 .chain()
1254 .map(|err| err.to_string())
1255 .collect::<Vec<_>>()
1256 .join("\n");
1257 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1258 header: "Error interacting with language model".into(),
1259 message: SharedString::from(error_message.clone()),
1260 }));
1261 }
1262
1263 thread.cancel_last_completion(cx);
1264 }
1265 }
1266 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1267
1268 thread.auto_capture_telemetry(cx);
1269
1270 if let Ok(initial_usage) = initial_token_usage {
1271 let usage = thread.cumulative_token_usage - initial_usage;
1272
1273 telemetry::event!(
1274 "Assistant Thread Completion",
1275 thread_id = thread.id().to_string(),
1276 model = model.telemetry_id(),
1277 model_provider = model.provider_id().to_string(),
1278 input_tokens = usage.input_tokens,
1279 output_tokens = usage.output_tokens,
1280 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1281 cache_read_input_tokens = usage.cache_read_input_tokens,
1282 );
1283 }
1284 })
1285 .ok();
1286 });
1287
1288 self.pending_completions.push(PendingCompletion {
1289 id: pending_completion_id,
1290 _task: task,
1291 });
1292 }
1293
1294 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1295 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1296 return;
1297 };
1298
1299 if !model.provider.is_authenticated(cx) {
1300 return;
1301 }
1302
1303 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1304 request.messages.push(LanguageModelRequestMessage {
1305 role: Role::User,
1306 content: vec![
1307 "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1308 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1309 If the conversation is about a specific subject, include it in the title. \
1310 Be descriptive. DO NOT speak in the first person."
1311 .into(),
1312 ],
1313 cache: false,
1314 });
1315
1316 self.pending_summary = cx.spawn(async move |this, cx| {
1317 async move {
1318 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1319 let (mut messages, usage) = stream.await?;
1320
1321 if let Some(usage) = usage {
1322 this.update(cx, |_thread, cx| {
1323 cx.emit(ThreadEvent::UsageUpdated(usage));
1324 })
1325 .ok();
1326 }
1327
1328 let mut new_summary = String::new();
1329 while let Some(message) = messages.stream.next().await {
1330 let text = message?;
1331 let mut lines = text.lines();
1332 new_summary.extend(lines.next());
1333
1334 // Stop if the LLM generated multiple lines.
1335 if lines.next().is_some() {
1336 break;
1337 }
1338 }
1339
1340 this.update(cx, |this, cx| {
1341 if !new_summary.is_empty() {
1342 this.summary = Some(new_summary.into());
1343 }
1344
1345 cx.emit(ThreadEvent::SummaryGenerated);
1346 })?;
1347
1348 anyhow::Ok(())
1349 }
1350 .log_err()
1351 .await
1352 });
1353 }
1354
1355 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1356 let last_message_id = self.messages.last().map(|message| message.id)?;
1357
1358 match &self.detailed_summary_state {
1359 DetailedSummaryState::Generating { message_id, .. }
1360 | DetailedSummaryState::Generated { message_id, .. }
1361 if *message_id == last_message_id =>
1362 {
1363 // Already up-to-date
1364 return None;
1365 }
1366 _ => {}
1367 }
1368
1369 let ConfiguredModel { model, provider } =
1370 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1371
1372 if !provider.is_authenticated(cx) {
1373 return None;
1374 }
1375
1376 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1377
1378 request.messages.push(LanguageModelRequestMessage {
1379 role: Role::User,
1380 content: vec![
1381 "Generate a detailed summary of this conversation. Include:\n\
1382 1. A brief overview of what was discussed\n\
1383 2. Key facts or information discovered\n\
1384 3. Outcomes or conclusions reached\n\
1385 4. Any action items or next steps if any\n\
1386 Format it in Markdown with headings and bullet points."
1387 .into(),
1388 ],
1389 cache: false,
1390 });
1391
1392 let task = cx.spawn(async move |thread, cx| {
1393 let stream = model.stream_completion_text(request, &cx);
1394 let Some(mut messages) = stream.await.log_err() else {
1395 thread
1396 .update(cx, |this, _cx| {
1397 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1398 })
1399 .log_err();
1400
1401 return;
1402 };
1403
1404 let mut new_detailed_summary = String::new();
1405
1406 while let Some(chunk) = messages.stream.next().await {
1407 if let Some(chunk) = chunk.log_err() {
1408 new_detailed_summary.push_str(&chunk);
1409 }
1410 }
1411
1412 thread
1413 .update(cx, |this, _cx| {
1414 this.detailed_summary_state = DetailedSummaryState::Generated {
1415 text: new_detailed_summary.into(),
1416 message_id: last_message_id,
1417 };
1418 })
1419 .log_err();
1420 });
1421
1422 self.detailed_summary_state = DetailedSummaryState::Generating {
1423 message_id: last_message_id,
1424 };
1425
1426 Some(task)
1427 }
1428
1429 pub fn is_generating_detailed_summary(&self) -> bool {
1430 matches!(
1431 self.detailed_summary_state,
1432 DetailedSummaryState::Generating { .. }
1433 )
1434 }
1435
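    /// Starts all idle pending tool uses, requesting confirmation for tools that need it
    /// (unless tool actions are always allowed) and running the rest immediately.
    /// Returns the tool uses that were pending.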
1436 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1437 self.auto_capture_telemetry(cx);
1438 let request = self.to_completion_request(RequestKind::Chat, cx);
1439 let messages = Arc::new(request.messages);
1440 let pending_tool_uses = self
1441 .tool_use
1442 .pending_tool_uses()
1443 .into_iter()
1444 .filter(|tool_use| tool_use.status.is_idle())
1445 .cloned()
1446 .collect::<Vec<_>>();
1447
1448 for tool_use in pending_tool_uses.iter() {
1449 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1450 if tool.needs_confirmation(&tool_use.input, cx)
1451 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1452 {
1453 self.tool_use.confirm_tool_use(
1454 tool_use.id.clone(),
1455 tool_use.ui_text.clone(),
1456 tool_use.input.clone(),
1457 messages.clone(),
1458 tool,
1459 );
1460 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1461 } else {
1462 self.run_tool(
1463 tool_use.id.clone(),
1464 tool_use.ui_text.clone(),
1465 tool_use.input.clone(),
1466 &messages,
1467 tool,
1468 cx,
1469 );
1470 }
1471 }
1472 }
1473
1474 pending_tool_uses
1475 }
1476
1477 pub fn run_tool(
1478 &mut self,
1479 tool_use_id: LanguageModelToolUseId,
1480 ui_text: impl Into<SharedString>,
1481 input: serde_json::Value,
1482 messages: &[LanguageModelRequestMessage],
1483 tool: Arc<dyn Tool>,
1484 cx: &mut Context<Thread>,
1485 ) {
1486 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1487 self.tool_use
1488 .run_pending_tool(tool_use_id, ui_text.into(), task);
1489 }
1490
1491 fn spawn_tool_use(
1492 &mut self,
1493 tool_use_id: LanguageModelToolUseId,
1494 messages: &[LanguageModelRequestMessage],
1495 input: serde_json::Value,
1496 tool: Arc<dyn Tool>,
1497 cx: &mut Context<Thread>,
1498 ) -> Task<()> {
1499 let tool_name: Arc<str> = tool.name().into();
1500
1501 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1502 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1503 } else {
1504 tool.run(
1505 input,
1506 messages,
1507 self.project.clone(),
1508 self.action_log.clone(),
1509 cx,
1510 )
1511 };
1512
1513 // Store the card separately if it exists
1514 if let Some(card) = tool_result.card.clone() {
1515 self.tool_use
1516 .insert_tool_result_card(tool_use_id.clone(), card);
1517 }
1518
1519 cx.spawn({
1520 async move |thread: WeakEntity<Thread>, cx| {
1521 let output = tool_result.output.await;
1522
1523 thread
1524 .update(cx, |thread, cx| {
1525 let pending_tool_use = thread.tool_use.insert_tool_output(
1526 tool_use_id.clone(),
1527 tool_name,
1528 output,
1529 cx,
1530 );
1531 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1532 })
1533 .ok();
1534 }
1535 })
1536 }
1537
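    /// Records that a tool use finished. Once all pending tools have finished, attaches the
    /// results as a new user message and, unless the run was canceled, sends a follow-up
    /// request to the model. Always emits [`ThreadEvent::ToolFinished`].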
1538 fn tool_finished(
1539 &mut self,
1540 tool_use_id: LanguageModelToolUseId,
1541 pending_tool_use: Option<PendingToolUse>,
1542 canceled: bool,
1543 cx: &mut Context<Self>,
1544 ) {
1545 if self.all_tools_finished() {
1546 let model_registry = LanguageModelRegistry::read_global(cx);
1547 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1548 self.attach_tool_results(cx);
1549 if !canceled {
1550 self.send_to_model(model, RequestKind::Chat, cx);
1551 }
1552 }
1553 }
1554
1555 cx.emit(ThreadEvent::ToolFinished {
1556 tool_use_id,
1557 pending_tool_use,
1558 });
1559 }
1560
1561 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1562 // Insert a user message to contain the tool results.
1563 self.insert_user_message(
1564 // TODO: Sending up a user message without any content results in the model sending back
1565 // responses that also don't have any content. We currently don't handle this case well,
1566 // so for now we provide some text to keep the model on track.
1567 "Here are the tool results.",
1568 Vec::new(),
1569 None,
1570 cx,
1571 );
1572 }
1573
1574 /// Cancels the last pending completion, if there are any pending.
1575 ///
1576 /// Returns whether a completion was canceled.
1577 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1578 let canceled = if self.pending_completions.pop().is_some() {
1579 true
1580 } else {
1581 let mut canceled = false;
1582 for pending_tool_use in self.tool_use.cancel_pending() {
1583 canceled = true;
1584 self.tool_finished(
1585 pending_tool_use.id.clone(),
1586 Some(pending_tool_use),
1587 true,
1588 cx,
1589 );
1590 }
1591 canceled
1592 };
1593 self.finalize_pending_checkpoint(cx);
1594 canceled
1595 }
1596
1597 pub fn feedback(&self) -> Option<ThreadFeedback> {
1598 self.feedback
1599 }
1600
1601 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1602 self.message_feedback.get(&message_id).copied()
1603 }
1604
1605 pub fn report_message_feedback(
1606 &mut self,
1607 message_id: MessageId,
1608 feedback: ThreadFeedback,
1609 cx: &mut Context<Self>,
1610 ) -> Task<Result<()>> {
1611 if self.message_feedback.get(&message_id) == Some(&feedback) {
1612 return Task::ready(Ok(()));
1613 }
1614
1615 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1616 let serialized_thread = self.serialize(cx);
1617 let thread_id = self.id().clone();
1618 let client = self.project.read(cx).client();
1619
1620 let enabled_tool_names: Vec<String> = self
1621 .tools()
1622 .read(cx)
1623 .enabled_tools(cx)
1624 .iter()
1625 .map(|tool| tool.name().to_string())
1626 .collect();
1627
1628 self.message_feedback.insert(message_id, feedback);
1629
1630 cx.notify();
1631
1632 let message_content = self
1633 .message(message_id)
1634 .map(|msg| msg.to_string())
1635 .unwrap_or_default();
1636
1637 cx.background_spawn(async move {
1638 let final_project_snapshot = final_project_snapshot.await;
1639 let serialized_thread = serialized_thread.await?;
1640 let thread_data =
1641 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1642
1643 let rating = match feedback {
1644 ThreadFeedback::Positive => "positive",
1645 ThreadFeedback::Negative => "negative",
1646 };
1647 telemetry::event!(
1648 "Assistant Thread Rated",
1649 rating,
1650 thread_id,
1651 enabled_tool_names,
1652 message_id = message_id.0,
1653 message_content,
1654 thread_data,
1655 final_project_snapshot
1656 );
1657 client.telemetry().flush_events();
1658
1659 Ok(())
1660 })
1661 }
1662
1663 pub fn report_feedback(
1664 &mut self,
1665 feedback: ThreadFeedback,
1666 cx: &mut Context<Self>,
1667 ) -> Task<Result<()>> {
1668 let last_assistant_message_id = self
1669 .messages
1670 .iter()
1671 .rev()
1672 .find(|msg| msg.role == Role::Assistant)
1673 .map(|msg| msg.id);
1674
1675 if let Some(message_id) = last_assistant_message_id {
1676 self.report_message_feedback(message_id, feedback, cx)
1677 } else {
1678 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1679 let serialized_thread = self.serialize(cx);
1680 let thread_id = self.id().clone();
1681 let client = self.project.read(cx).client();
1682 self.feedback = Some(feedback);
1683 cx.notify();
1684
1685 cx.background_spawn(async move {
1686 let final_project_snapshot = final_project_snapshot.await;
1687 let serialized_thread = serialized_thread.await?;
1688 let thread_data = serde_json::to_value(serialized_thread)
1689 .unwrap_or_else(|_| serde_json::Value::Null);
1690
1691 let rating = match feedback {
1692 ThreadFeedback::Positive => "positive",
1693 ThreadFeedback::Negative => "negative",
1694 };
1695 telemetry::event!(
1696 "Assistant Thread Rated",
1697 rating,
1698 thread_id,
1699 thread_data,
1700 final_project_snapshot
1701 );
1702 client.telemetry().flush_events();
1703
1704 Ok(())
1705 })
1706 }
1707 }
1708
1709 /// Create a snapshot of the current project state including git information and unsaved buffers.
1710 fn project_snapshot(
1711 project: Entity<Project>,
1712 cx: &mut Context<Self>,
1713 ) -> Task<Arc<ProjectSnapshot>> {
1714 let git_store = project.read(cx).git_store().clone();
1715 let worktree_snapshots: Vec<_> = project
1716 .read(cx)
1717 .visible_worktrees(cx)
1718 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1719 .collect();
1720
1721 cx.spawn(async move |_, cx| {
1722 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1723
1724 let mut unsaved_buffers = Vec::new();
1725 cx.update(|app_cx| {
1726 let buffer_store = project.read(app_cx).buffer_store();
1727 for buffer_handle in buffer_store.read(app_cx).buffers() {
1728 let buffer = buffer_handle.read(app_cx);
1729 if buffer.is_dirty() {
1730 if let Some(file) = buffer.file() {
1731 let path = file.path().to_string_lossy().to_string();
1732 unsaved_buffers.push(path);
1733 }
1734 }
1735 }
1736 })
1737 .ok();
1738
1739 Arc::new(ProjectSnapshot {
1740 worktree_snapshots,
1741 unsaved_buffer_paths: unsaved_buffers,
1742 timestamp: Utc::now(),
1743 })
1744 })
1745 }
1746
1747 fn worktree_snapshot(
1748 worktree: Entity<project::Worktree>,
1749 git_store: Entity<GitStore>,
1750 cx: &App,
1751 ) -> Task<WorktreeSnapshot> {
1752 cx.spawn(async move |cx| {
1753 // Get worktree path and snapshot
1754 let worktree_info = cx.update(|app_cx| {
1755 let worktree = worktree.read(app_cx);
1756 let path = worktree.abs_path().to_string_lossy().to_string();
1757 let snapshot = worktree.snapshot();
1758 (path, snapshot)
1759 });
1760
1761 let Ok((worktree_path, _snapshot)) = worktree_info else {
1762 return WorktreeSnapshot {
1763 worktree_path: String::new(),
1764 git_state: None,
1765 };
1766 };
1767
1768 let git_state = git_store
1769 .update(cx, |git_store, cx| {
1770 git_store
1771 .repositories()
1772 .values()
1773 .find(|repo| {
1774 repo.read(cx)
1775 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1776 .is_some()
1777 })
1778 .cloned()
1779 })
1780 .ok()
1781 .flatten()
1782 .map(|repo| {
1783 repo.update(cx, |repo, _| {
1784 let current_branch =
1785 repo.branch.as_ref().map(|branch| branch.name.to_string());
1786 repo.send_job(None, |state, _| async move {
1787 let RepositoryState::Local { backend, .. } = state else {
1788 return GitState {
1789 remote_url: None,
1790 head_sha: None,
1791 current_branch,
1792 diff: None,
1793 };
1794 };
1795
1796 let remote_url = backend.remote_url("origin");
1797 let head_sha = backend.head_sha();
1798 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1799
1800 GitState {
1801 remote_url,
1802 head_sha,
1803 current_branch,
1804 diff,
1805 }
1806 })
1807 })
1808 });
1809
1810 let git_state = match git_state {
1811 Some(git_state) => match git_state.ok() {
1812 Some(git_state) => git_state.await.ok(),
1813 None => None,
1814 },
1815 None => None,
1816 };
1817
1818 WorktreeSnapshot {
1819 worktree_path,
1820 git_state,
1821 }
1822 })
1823 }
1824
1825 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1826 let mut markdown = Vec::new();
1827
1828 if let Some(summary) = self.summary() {
1829 writeln!(markdown, "# {summary}\n")?;
1830 };
1831
1832 for message in self.messages() {
1833 writeln!(
1834 markdown,
1835 "## {role}\n",
1836 role = match message.role {
1837 Role::User => "User",
1838 Role::Assistant => "Assistant",
1839 Role::System => "System",
1840 }
1841 )?;
1842
1843 if !message.context.is_empty() {
1844 writeln!(markdown, "{}", message.context)?;
1845 }
1846
1847 for segment in &message.segments {
1848 match segment {
1849 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1850 MessageSegment::Thinking(text) => {
1851 writeln!(markdown, "<think>{}</think>\n", text)?
1852 }
1853 }
1854 }
1855
1856 for tool_use in self.tool_uses_for_message(message.id, cx) {
1857 writeln!(
1858 markdown,
1859 "**Use Tool: {} ({})**",
1860 tool_use.name, tool_use.id
1861 )?;
1862 writeln!(markdown, "```json")?;
1863 writeln!(
1864 markdown,
1865 "{}",
1866 serde_json::to_string_pretty(&tool_use.input)?
1867 )?;
1868 writeln!(markdown, "```")?;
1869 }
1870
1871 for tool_result in self.tool_results_for_message(message.id) {
1872 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1873 if tool_result.is_error {
1874 write!(markdown, " (Error)")?;
1875 }
1876
1877 writeln!(markdown, "**\n")?;
1878 writeln!(markdown, "{}", tool_result.content)?;
1879 }
1880 }
1881
1882 Ok(String::from_utf8_lossy(&markdown).to_string())
1883 }
1884
1885 pub fn keep_edits_in_range(
1886 &mut self,
1887 buffer: Entity<language::Buffer>,
1888 buffer_range: Range<language::Anchor>,
1889 cx: &mut Context<Self>,
1890 ) {
1891 self.action_log.update(cx, |action_log, cx| {
1892 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1893 });
1894 }
1895
1896 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1897 self.action_log
1898 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1899 }
1900
1901 pub fn reject_edits_in_ranges(
1902 &mut self,
1903 buffer: Entity<language::Buffer>,
1904 buffer_ranges: Vec<Range<language::Anchor>>,
1905 cx: &mut Context<Self>,
1906 ) -> Task<Result<()>> {
1907 self.action_log.update(cx, |action_log, cx| {
1908 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1909 })
1910 }
1911
1912 pub fn action_log(&self) -> &Entity<ActionLog> {
1913 &self.action_log
1914 }
1915
1916 pub fn project(&self) -> &Entity<Project> {
1917 &self.project
1918 }
1919
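    /// Captures the serialized thread for telemetry when the auto-capture feature flag is
    /// enabled, at most once every ten seconds.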
1920 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1921 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1922 return;
1923 }
1924
1925 let now = Instant::now();
1926 if let Some(last) = self.last_auto_capture_at {
1927 if now.duration_since(last).as_secs() < 10 {
1928 return;
1929 }
1930 }
1931
1932 self.last_auto_capture_at = Some(now);
1933
1934 let thread_id = self.id().clone();
1935 let github_login = self
1936 .project
1937 .read(cx)
1938 .user_store()
1939 .read(cx)
1940 .current_user()
1941 .map(|user| user.github_login.clone());
1942 let client = self.project.read(cx).client().clone();
1943 let serialize_task = self.serialize(cx);
1944
1945 cx.background_executor()
1946 .spawn(async move {
1947 if let Ok(serialized_thread) = serialize_task.await {
1948 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1949 telemetry::event!(
1950 "Agent Thread Auto-Captured",
1951 thread_id = thread_id.to_string(),
1952 thread_data = thread_data,
1953 auto_capture_reason = "tracked_user",
1954 github_login = github_login
1955 );
1956
1957 client.telemetry().flush_events();
1958 }
1959 }
1960 })
1961 .detach();
1962 }
1963
1964 pub fn cumulative_token_usage(&self) -> TokenUsage {
1965 self.cumulative_token_usage
1966 }
1967
1968 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
1969 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
1970 return TotalTokenUsage::default();
1971 };
1972
1973 let max = model.model.max_token_count();
1974
1975 let index = self
1976 .messages
1977 .iter()
1978 .position(|msg| msg.id == message_id)
1979 .unwrap_or(0);
1980
1981 if index == 0 {
1982 return TotalTokenUsage { total: 0, max };
1983 }
1984
1985 let token_usage = &self
1986 .request_token_usage
1987 .get(index - 1)
1988 .cloned()
1989 .unwrap_or_default();
1990
1991 TotalTokenUsage {
1992 total: token_usage.total_tokens() as usize,
1993 max,
1994 }
1995 }
1996
    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
        let model_registry = LanguageModelRegistry::read_global(cx);
        let Some(model) = model_registry.default_model() else {
            return TotalTokenUsage::default();
        };

        let max = model.model.max_token_count();

        if let Some(exceeded_error) = &self.exceeded_window_error {
            if model.model.id() == exceeded_error.model_id {
                return TotalTokenUsage {
                    total: exceeded_error.token_count,
                    max,
                };
            }
        }

        let total = self
            .token_usage_at_last_message()
            .unwrap_or_default()
            .total_tokens() as usize;

        TotalTokenUsage { total, max }
    }

    /// Returns the token usage recorded for the most recent message, if any.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }

    /// Records `token_usage` for the most recent message, padding the usage
    /// history so it stays aligned with the message list.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }

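    /// Records that the user denied permission for a tool use: a permission-denied
    /// error is stored as the tool's output and the tool use is marked as finished.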
    pub fn deny_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        cx: &mut Context<Self>,
    ) {
        let err = Err(anyhow::anyhow!(
            "Permission to run tool action denied by user"
        ));

        self.tool_use
            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
        self.tool_finished(tool_use_id.clone(), None, true, cx);
    }
}

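/// User-facing errors raised while running a thread, surfaced to the UI via
/// [`ThreadEvent::ShowError`].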
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    #[error("Payment required")]
    PaymentRequired,
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}

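/// Events emitted by a [`Thread`] while completions stream in, messages change,
/// and tools run. Subscribers use these to stay in sync with the thread's state.
///
/// A minimal subscription sketch (illustrative only), assuming a
/// `thread: Entity<Thread>` handle and a `Context` for some subscribing view:
///
/// ```ignore
/// cx.subscribe(&thread, |_this, _thread, event, _cx| match event {
///     ThreadEvent::ShowError(error) => log::error!("thread error: {error}"),
///     ThreadEvent::MessageAdded(id) => log::debug!("message added: {id:?}"),
///     _ => {}
/// })
/// .detach();
/// ```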
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    ShowError(ThreadError),
    UsageUpdated(RequestUsage),
    StreamedCompletion,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool use, if any.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}

impl EventEmitter<ThreadEvent> for Thread {}

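/// An in-flight completion request. The spawned task is held here so the
/// streaming work keeps running for as long as the completion is pending.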
struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // The expected path uses the platform's separator
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check which contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check the entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain all three user messages (plus the leading context message)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert a new line near the start of the buffer so it becomes stale
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

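    // Helper to create a workspace, thread store, empty thread, and context store
    // for the given test project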
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

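    // Helper to open the buffer at `path` and add it to the context store,
    // returning the opened buffer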
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}