use std::fmt::Write as _;
use std::io::Write;
use std::ops::Range;
use std::sync::Arc;
use std::time::Instant;

use anyhow::{Context as _, Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap};
use feature_flags::{self, FeatureFlagAppExt};
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
    ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
};
use project::Project;
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
use prompt_store::PromptBuilder;
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;

use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::thread_store::{
    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
    SerializedToolUse, SharedProjectContext,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};

42#[derive(Debug, Clone, Copy)]
43pub enum RequestKind {
44 Chat,
45 /// Used when summarizing a thread.
46 Summarize,
47}
48
49#[derive(
50 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
51)]
52pub struct ThreadId(Arc<str>);
53
54impl ThreadId {
55 pub fn new() -> Self {
56 Self(Uuid::new_v4().to_string().into())
57 }
58}
59
60impl std::fmt::Display for ThreadId {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(f, "{}", self.0)
63 }
64}
65
66impl From<&str> for ThreadId {
67 fn from(value: &str) -> Self {
68 Self(value.into())
69 }
70}
71
72#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
73pub struct MessageId(pub(crate) usize);
74
75impl MessageId {
76 fn post_inc(&mut self) -> Self {
77 Self(post_inc(&mut self.0))
78 }
79}
80
81/// A message in a [`Thread`].
82#[derive(Debug, Clone)]
83pub struct Message {
84 pub id: MessageId,
85 pub role: Role,
86 pub segments: Vec<MessageSegment>,
87 pub context: String,
88}
89
90impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    /// The model sometimes runs tools without producing any text, or produces only a
    /// marker ([`USING_TOOL_MARKER`]).
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().all(|segment| segment.should_display())
    }
96
97 pub fn push_thinking(&mut self, text: &str) {
98 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
99 segment.push_str(text);
100 } else {
101 self.segments
102 .push(MessageSegment::Thinking(text.to_string()));
103 }
104 }
105
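    /// Appends `text` to the last segment when it is already a [`MessageSegment::Text`],
    /// otherwise starts a new text segment ([`Self::push_thinking`] behaves the same way
    /// for thinking segments).
    ///
    /// Illustrative sketch (not from the original source; not compiled as a doctest):
    ///
    /// ```ignore
    /// let mut message = Message {
    ///     id: MessageId(0),
    ///     role: Role::Assistant,
    ///     segments: Vec::new(),
    ///     context: String::new(),
    /// };
    /// message.push_text("Hello, ");
    /// message.push_text("world!");
    /// // Consecutive text chunks are merged into a single segment.
    /// assert_eq!(message.segments.len(), 1);
    /// ```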
106 pub fn push_text(&mut self, text: &str) {
107 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
108 segment.push_str(text);
109 } else {
110 self.segments.push(MessageSegment::Text(text.to_string()));
111 }
112 }
113
114 pub fn to_string(&self) -> String {
115 let mut result = String::new();
116
117 if !self.context.is_empty() {
118 result.push_str(&self.context);
119 }
120
121 for segment in &self.segments {
122 match segment {
123 MessageSegment::Text(text) => result.push_str(text),
124 MessageSegment::Thinking(text) => {
125 result.push_str("<think>");
126 result.push_str(text);
127 result.push_str("</think>");
128 }
129 }
130 }
131
132 result
133 }
134}
135
136#[derive(Debug, Clone, PartialEq, Eq)]
137pub enum MessageSegment {
138 Text(String),
139 Thinking(String),
140}
141
142impl MessageSegment {
143 pub fn text_mut(&mut self) -> &mut String {
144 match self {
145 Self::Text(text) => text,
146 Self::Thinking(text) => text,
147 }
148 }
149
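    /// Whether this segment holds user-visible content, i.e. it is neither empty nor
    /// just the [`USING_TOOL_MARKER`].
    ///
    /// Hedged example of the intended behavior (illustrative, not a doctest):
    ///
    /// ```ignore
    /// assert!(MessageSegment::Text("Reading the file now.".into()).should_display());
    /// assert!(!MessageSegment::Text(USING_TOOL_MARKER.to_string()).should_display());
    /// assert!(!MessageSegment::Thinking(String::new()).should_display());
    /// ```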
    pub fn should_display(&self) -> bool {
        // We add USING_TOOL_MARKER when making a request that includes tool uses
        // without non-whitespace text around them, and this can cause the model
        // to mimic the pattern, so we consider those segments not displayable.
        match self {
            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::Thinking(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
        }
    }
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct ProjectSnapshot {
163 pub worktree_snapshots: Vec<WorktreeSnapshot>,
164 pub unsaved_buffer_paths: Vec<String>,
165 pub timestamp: DateTime<Utc>,
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct WorktreeSnapshot {
170 pub worktree_path: String,
171 pub git_state: Option<GitState>,
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct GitState {
176 pub remote_url: Option<String>,
177 pub head_sha: Option<String>,
178 pub current_branch: Option<String>,
179 pub diff: Option<String>,
180}
181
182#[derive(Clone)]
183pub struct ThreadCheckpoint {
184 message_id: MessageId,
185 git_checkpoint: GitStoreCheckpoint,
186}
187
188#[derive(Copy, Clone, Debug, PartialEq, Eq)]
189pub enum ThreadFeedback {
190 Positive,
191 Negative,
192}
193
194pub enum LastRestoreCheckpoint {
195 Pending {
196 message_id: MessageId,
197 },
198 Error {
199 message_id: MessageId,
200 error: String,
201 },
202}
203
204impl LastRestoreCheckpoint {
205 pub fn message_id(&self) -> MessageId {
206 match self {
207 LastRestoreCheckpoint::Pending { message_id } => *message_id,
208 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
209 }
210 }
211}
212
213#[derive(Clone, Debug, Default, Serialize, Deserialize)]
214pub enum DetailedSummaryState {
215 #[default]
216 NotGenerated,
217 Generating {
218 message_id: MessageId,
219 },
220 Generated {
221 text: SharedString,
222 message_id: MessageId,
223 },
224}
225
226#[derive(Default)]
227pub struct TotalTokenUsage {
228 pub total: usize,
229 pub max: usize,
230}
231
232impl TotalTokenUsage {
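    /// Classifies current usage against the model's context window: `Exceeded` at or
    /// above the maximum, `Warning` at or above the warning threshold (0.8 by default,
    /// overridable via `ZED_THREAD_WARNING_THRESHOLD` in debug builds), `Normal` otherwise.
    ///
    /// Illustrative values (sketch, not a doctest):
    ///
    /// ```ignore
    /// let usage = TotalTokenUsage { total: 90_000, max: 100_000 };
    /// assert_eq!(usage.ratio(), TokenUsageRatio::Warning);
    /// assert_eq!(usage.add(10_000).ratio(), TokenUsageRatio::Exceeded);
    /// ```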
233 pub fn ratio(&self) -> TokenUsageRatio {
234 #[cfg(debug_assertions)]
235 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
236 .unwrap_or("0.8".to_string())
237 .parse()
238 .unwrap();
239 #[cfg(not(debug_assertions))]
240 let warning_threshold: f32 = 0.8;
241
242 if self.total >= self.max {
243 TokenUsageRatio::Exceeded
244 } else if self.total as f32 / self.max as f32 >= warning_threshold {
245 TokenUsageRatio::Warning
246 } else {
247 TokenUsageRatio::Normal
248 }
249 }
250
251 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
252 TotalTokenUsage {
253 total: self.total + tokens,
254 max: self.max,
255 }
256 }
257}
258
259#[derive(Debug, Default, PartialEq, Eq)]
260pub enum TokenUsageRatio {
261 #[default]
262 Normal,
263 Warning,
264 Exceeded,
265}
266
267/// A thread of conversation with the LLM.
268pub struct Thread {
269 id: ThreadId,
270 updated_at: DateTime<Utc>,
271 summary: Option<SharedString>,
272 pending_summary: Task<Option<()>>,
273 detailed_summary_state: DetailedSummaryState,
274 messages: Vec<Message>,
275 next_message_id: MessageId,
276 context: BTreeMap<ContextId, AssistantContext>,
277 context_by_message: HashMap<MessageId, Vec<ContextId>>,
278 project_context: SharedProjectContext,
279 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
280 completion_count: usize,
281 pending_completions: Vec<PendingCompletion>,
282 project: Entity<Project>,
283 prompt_builder: Arc<PromptBuilder>,
284 tools: Entity<ToolWorkingSet>,
285 tool_use: ToolUseState,
286 action_log: Entity<ActionLog>,
287 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
288 pending_checkpoint: Option<ThreadCheckpoint>,
289 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
290 request_token_usage: Vec<TokenUsage>,
291 cumulative_token_usage: TokenUsage,
292 exceeded_window_error: Option<ExceededWindowError>,
293 feedback: Option<ThreadFeedback>,
294 message_feedback: HashMap<MessageId, ThreadFeedback>,
295 last_auto_capture_at: Option<Instant>,
296}
297
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// The model in use when the last message exceeded the context window.
    model_id: LanguageModelId,
    /// Token count including the last message.
    token_count: usize,
}
305
306impl Thread {
307 pub fn new(
308 project: Entity<Project>,
309 tools: Entity<ToolWorkingSet>,
310 prompt_builder: Arc<PromptBuilder>,
311 system_prompt: SharedProjectContext,
312 cx: &mut Context<Self>,
313 ) -> Self {
314 Self {
315 id: ThreadId::new(),
316 updated_at: Utc::now(),
317 summary: None,
318 pending_summary: Task::ready(None),
319 detailed_summary_state: DetailedSummaryState::NotGenerated,
320 messages: Vec::new(),
321 next_message_id: MessageId(0),
322 context: BTreeMap::default(),
323 context_by_message: HashMap::default(),
324 project_context: system_prompt,
325 checkpoints_by_message: HashMap::default(),
326 completion_count: 0,
327 pending_completions: Vec::new(),
328 project: project.clone(),
329 prompt_builder,
330 tools: tools.clone(),
331 last_restore_checkpoint: None,
332 pending_checkpoint: None,
333 tool_use: ToolUseState::new(tools.clone()),
334 action_log: cx.new(|_| ActionLog::new(project.clone())),
335 initial_project_snapshot: {
336 let project_snapshot = Self::project_snapshot(project, cx);
337 cx.foreground_executor()
338 .spawn(async move { Some(project_snapshot.await) })
339 .shared()
340 },
341 request_token_usage: Vec::new(),
342 cumulative_token_usage: TokenUsage::default(),
343 exceeded_window_error: None,
344 feedback: None,
345 message_feedback: HashMap::default(),
346 last_auto_capture_at: None,
347 }
348 }
349
350 pub fn deserialize(
351 id: ThreadId,
352 serialized: SerializedThread,
353 project: Entity<Project>,
354 tools: Entity<ToolWorkingSet>,
355 prompt_builder: Arc<PromptBuilder>,
356 project_context: SharedProjectContext,
357 cx: &mut Context<Self>,
358 ) -> Self {
359 let next_message_id = MessageId(
360 serialized
361 .messages
362 .last()
363 .map(|message| message.id.0 + 1)
364 .unwrap_or(0),
365 );
366 let tool_use =
367 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
368
369 Self {
370 id,
371 updated_at: serialized.updated_at,
372 summary: Some(serialized.summary),
373 pending_summary: Task::ready(None),
374 detailed_summary_state: serialized.detailed_summary_state,
375 messages: serialized
376 .messages
377 .into_iter()
378 .map(|message| Message {
379 id: message.id,
380 role: message.role,
381 segments: message
382 .segments
383 .into_iter()
384 .map(|segment| match segment {
385 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
386 SerializedMessageSegment::Thinking { text } => {
387 MessageSegment::Thinking(text)
388 }
389 })
390 .collect(),
391 context: message.context,
392 })
393 .collect(),
394 next_message_id,
395 context: BTreeMap::default(),
396 context_by_message: HashMap::default(),
397 project_context,
398 checkpoints_by_message: HashMap::default(),
399 completion_count: 0,
400 pending_completions: Vec::new(),
401 last_restore_checkpoint: None,
402 pending_checkpoint: None,
403 project: project.clone(),
404 prompt_builder,
405 tools,
406 tool_use,
407 action_log: cx.new(|_| ActionLog::new(project)),
408 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
409 request_token_usage: serialized.request_token_usage,
410 cumulative_token_usage: serialized.cumulative_token_usage,
411 exceeded_window_error: None,
412 feedback: None,
413 message_feedback: HashMap::default(),
414 last_auto_capture_at: None,
415 }
416 }
417
418 pub fn id(&self) -> &ThreadId {
419 &self.id
420 }
421
422 pub fn is_empty(&self) -> bool {
423 self.messages.is_empty()
424 }
425
426 pub fn updated_at(&self) -> DateTime<Utc> {
427 self.updated_at
428 }
429
430 pub fn touch_updated_at(&mut self) {
431 self.updated_at = Utc::now();
432 }
433
434 pub fn summary(&self) -> Option<SharedString> {
435 self.summary.clone()
436 }
437
438 pub fn project_context(&self) -> SharedProjectContext {
439 self.project_context.clone()
440 }
441
442 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
443
444 pub fn summary_or_default(&self) -> SharedString {
445 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
446 }
447
448 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
449 let Some(current_summary) = &self.summary else {
450 // Don't allow setting summary until generated
451 return;
452 };
453
454 let mut new_summary = new_summary.into();
455
456 if new_summary.is_empty() {
457 new_summary = Self::DEFAULT_SUMMARY;
458 }
459
460 if current_summary != &new_summary {
461 self.summary = Some(new_summary);
462 cx.emit(ThreadEvent::SummaryChanged);
463 }
464 }
465
466 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
467 self.latest_detailed_summary()
468 .unwrap_or_else(|| self.text().into())
469 }
470
471 fn latest_detailed_summary(&self) -> Option<SharedString> {
472 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
473 Some(text.clone())
474 } else {
475 None
476 }
477 }
478
479 pub fn message(&self, id: MessageId) -> Option<&Message> {
480 self.messages.iter().find(|message| message.id == id)
481 }
482
483 pub fn messages(&self) -> impl Iterator<Item = &Message> {
484 self.messages.iter()
485 }
486
487 pub fn is_generating(&self) -> bool {
488 !self.pending_completions.is_empty() || !self.all_tools_finished()
489 }
490
491 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
492 &self.tools
493 }
494
495 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
496 self.tool_use
497 .pending_tool_uses()
498 .into_iter()
499 .find(|tool_use| &tool_use.id == id)
500 }
501
502 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
503 self.tool_use
504 .pending_tool_uses()
505 .into_iter()
506 .filter(|tool_use| tool_use.status.needs_confirmation())
507 }
508
509 pub fn has_pending_tool_uses(&self) -> bool {
510 !self.tool_use.pending_tool_uses().is_empty()
511 }
512
513 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
514 self.checkpoints_by_message.get(&id).cloned()
515 }
516
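    /// Restores the project to the git state captured in `checkpoint` and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Hedged usage sketch (assumes an `Entity<Thread>` named `thread` and a GPUI `cx`;
    /// error handling elided):
    ///
    /// ```ignore
    /// if let Some(checkpoint) = thread.read(cx).checkpoint_for_message(message_id) {
    ///     let restore = thread.update(cx, |thread, cx| thread.restore_checkpoint(checkpoint, cx));
    ///     // `restore` resolves once the git state has been rolled back.
    /// }
    /// ```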
517 pub fn restore_checkpoint(
518 &mut self,
519 checkpoint: ThreadCheckpoint,
520 cx: &mut Context<Self>,
521 ) -> Task<Result<()>> {
522 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
523 message_id: checkpoint.message_id,
524 });
525 cx.emit(ThreadEvent::CheckpointChanged);
526 cx.notify();
527
528 let git_store = self.project().read(cx).git_store().clone();
529 let restore = git_store.update(cx, |git_store, cx| {
530 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
531 });
532
533 cx.spawn(async move |this, cx| {
534 let result = restore.await;
535 this.update(cx, |this, cx| {
536 if let Err(err) = result.as_ref() {
537 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
538 message_id: checkpoint.message_id,
539 error: err.to_string(),
540 });
541 } else {
542 this.truncate(checkpoint.message_id, cx);
543 this.last_restore_checkpoint = None;
544 }
545 this.pending_checkpoint = None;
546 cx.emit(ThreadEvent::CheckpointChanged);
547 cx.notify();
548 })?;
549 result
550 })
551 }
552
553 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
554 let pending_checkpoint = if self.is_generating() {
555 return;
556 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
557 checkpoint
558 } else {
559 return;
560 };
561
562 let git_store = self.project.read(cx).git_store().clone();
563 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
564 cx.spawn(async move |this, cx| match final_checkpoint.await {
565 Ok(final_checkpoint) => {
566 let equal = git_store
567 .update(cx, |store, cx| {
568 store.compare_checkpoints(
569 pending_checkpoint.git_checkpoint.clone(),
570 final_checkpoint.clone(),
571 cx,
572 )
573 })?
574 .await
575 .unwrap_or(false);
576
577 if equal {
578 git_store
579 .update(cx, |store, cx| {
580 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
581 })?
582 .detach();
583 } else {
584 this.update(cx, |this, cx| {
585 this.insert_checkpoint(pending_checkpoint, cx)
586 })?;
587 }
588
589 git_store
590 .update(cx, |store, cx| {
591 store.delete_checkpoint(final_checkpoint, cx)
592 })?
593 .detach();
594
595 Ok(())
596 }
597 Err(_) => this.update(cx, |this, cx| {
598 this.insert_checkpoint(pending_checkpoint, cx)
599 }),
600 })
601 .detach();
602 }
603
604 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
605 self.checkpoints_by_message
606 .insert(checkpoint.message_id, checkpoint);
607 cx.emit(ThreadEvent::CheckpointChanged);
608 cx.notify();
609 }
610
611 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
612 self.last_restore_checkpoint.as_ref()
613 }
614
615 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
616 let Some(message_ix) = self
617 .messages
618 .iter()
619 .rposition(|message| message.id == message_id)
620 else {
621 return;
622 };
623 for deleted_message in self.messages.drain(message_ix..) {
624 self.context_by_message.remove(&deleted_message.id);
625 self.checkpoints_by_message.remove(&deleted_message.id);
626 }
627 cx.notify();
628 }
629
630 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
631 self.context_by_message
632 .get(&id)
633 .into_iter()
634 .flat_map(|context| {
635 context
636 .iter()
637 .filter_map(|context_id| self.context.get(&context_id))
638 })
639 }
640
641 /// Returns whether all of the tool uses have finished running.
642 pub fn all_tools_finished(&self) -> bool {
643 // If the only pending tool uses left are the ones with errors, then
644 // that means that we've finished running all of the pending tools.
645 self.tool_use
646 .pending_tool_uses()
647 .iter()
648 .all(|tool_use| tool_use.status.is_error())
649 }
650
651 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
652 self.tool_use.tool_uses_for_message(id, cx)
653 }
654
655 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
656 self.tool_use.tool_results_for_message(id)
657 }
658
659 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
660 self.tool_use.tool_result(id)
661 }
662
663 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
664 Some(&self.tool_use.tool_result(id)?.content)
665 }
666
667 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
668 self.tool_use.tool_result_card(id).cloned()
669 }
670
671 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
672 self.tool_use.message_has_tool_results(message_id)
673 }
674
675 /// Filter out contexts that have already been included in previous messages
676 pub fn filter_new_context<'a>(
677 &self,
678 context: impl Iterator<Item = &'a AssistantContext>,
679 ) -> impl Iterator<Item = &'a AssistantContext> {
680 context.filter(|ctx| self.is_context_new(ctx))
681 }
682
683 fn is_context_new(&self, context: &AssistantContext) -> bool {
684 !self.context.contains_key(&context.id())
685 }
686
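    /// Inserts a user message, records any new context attached to it, and optionally
    /// registers a pending git checkpoint for the message.
    ///
    /// Illustrative flow for submitting a user turn (assumes a configured
    /// `model: Arc<dyn LanguageModel>` and an `Entity<Thread>`; context and checkpoints
    /// are omitted):
    ///
    /// ```ignore
    /// thread.update(cx, |thread, cx| {
    ///     thread.insert_user_message("Explain this function", Vec::new(), None, cx);
    ///     thread.send_to_model(model, RequestKind::Chat, cx);
    /// });
    /// ```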
687 pub fn insert_user_message(
688 &mut self,
689 text: impl Into<String>,
690 context: Vec<AssistantContext>,
691 git_checkpoint: Option<GitStoreCheckpoint>,
692 cx: &mut Context<Self>,
693 ) -> MessageId {
694 let text = text.into();
695
696 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
697
698 let new_context: Vec<_> = context
699 .into_iter()
700 .filter(|ctx| self.is_context_new(ctx))
701 .collect();
702
703 if !new_context.is_empty() {
704 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
705 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
706 message.context = context_string;
707 }
708 }
709
710 self.action_log.update(cx, |log, cx| {
711 // Track all buffers added as context
712 for ctx in &new_context {
713 match ctx {
714 AssistantContext::File(file_ctx) => {
715 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
716 }
717 AssistantContext::Directory(dir_ctx) => {
718 for context_buffer in &dir_ctx.context_buffers {
719 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
720 }
721 }
722 AssistantContext::Symbol(symbol_ctx) => {
723 log.buffer_added_as_context(
724 symbol_ctx.context_symbol.buffer.clone(),
725 cx,
726 );
727 }
728 AssistantContext::Excerpt(excerpt_context) => {
729 log.buffer_added_as_context(
730 excerpt_context.context_buffer.buffer.clone(),
731 cx,
732 );
733 }
734 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
735 }
736 }
737 });
738 }
739
740 let context_ids = new_context
741 .iter()
742 .map(|context| context.id())
743 .collect::<Vec<_>>();
744 self.context.extend(
745 new_context
746 .into_iter()
747 .map(|context| (context.id(), context)),
748 );
749 self.context_by_message.insert(message_id, context_ids);
750
751 if let Some(git_checkpoint) = git_checkpoint {
752 self.pending_checkpoint = Some(ThreadCheckpoint {
753 message_id,
754 git_checkpoint,
755 });
756 }
757
758 self.auto_capture_telemetry(cx);
759
760 message_id
761 }
762
763 pub fn insert_message(
764 &mut self,
765 role: Role,
766 segments: Vec<MessageSegment>,
767 cx: &mut Context<Self>,
768 ) -> MessageId {
769 let id = self.next_message_id.post_inc();
770 self.messages.push(Message {
771 id,
772 role,
773 segments,
774 context: String::new(),
775 });
776 self.touch_updated_at();
777 cx.emit(ThreadEvent::MessageAdded(id));
778 id
779 }
780
781 pub fn edit_message(
782 &mut self,
783 id: MessageId,
784 new_role: Role,
785 new_segments: Vec<MessageSegment>,
786 cx: &mut Context<Self>,
787 ) -> bool {
788 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
789 return false;
790 };
791 message.role = new_role;
792 message.segments = new_segments;
793 self.touch_updated_at();
794 cx.emit(ThreadEvent::MessageEdited(id));
795 true
796 }
797
798 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
799 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
800 return false;
801 };
802 self.messages.remove(index);
803 self.context_by_message.remove(&id);
804 self.touch_updated_at();
805 cx.emit(ThreadEvent::MessageDeleted(id));
806 true
807 }
808
809 /// Returns the representation of this [`Thread`] in a textual form.
810 ///
811 /// This is the representation we use when attaching a thread as context to another thread.
812 pub fn text(&self) -> String {
813 let mut text = String::new();
814
815 for message in &self.messages {
816 text.push_str(match message.role {
817 language_model::Role::User => "User:",
818 language_model::Role::Assistant => "Assistant:",
819 language_model::Role::System => "System:",
820 });
821 text.push('\n');
822
823 for segment in &message.segments {
824 match segment {
825 MessageSegment::Text(content) => text.push_str(content),
826 MessageSegment::Thinking(content) => {
827 text.push_str(&format!("<think>{}</think>", content))
828 }
829 }
830 }
831 text.push('\n');
832 }
833
834 text
835 }
836
837 /// Serializes this thread into a format for storage or telemetry.
838 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
839 let initial_project_snapshot = self.initial_project_snapshot.clone();
840 cx.spawn(async move |this, cx| {
841 let initial_project_snapshot = initial_project_snapshot.await;
842 this.read_with(cx, |this, cx| SerializedThread {
843 version: SerializedThread::VERSION.to_string(),
844 summary: this.summary_or_default(),
845 updated_at: this.updated_at(),
846 messages: this
847 .messages()
848 .map(|message| SerializedMessage {
849 id: message.id,
850 role: message.role,
851 segments: message
852 .segments
853 .iter()
854 .map(|segment| match segment {
855 MessageSegment::Text(text) => {
856 SerializedMessageSegment::Text { text: text.clone() }
857 }
858 MessageSegment::Thinking(text) => {
859 SerializedMessageSegment::Thinking { text: text.clone() }
860 }
861 })
862 .collect(),
863 tool_uses: this
864 .tool_uses_for_message(message.id, cx)
865 .into_iter()
866 .map(|tool_use| SerializedToolUse {
867 id: tool_use.id,
868 name: tool_use.name,
869 input: tool_use.input,
870 })
871 .collect(),
872 tool_results: this
873 .tool_results_for_message(message.id)
874 .into_iter()
875 .map(|tool_result| SerializedToolResult {
876 tool_use_id: tool_result.tool_use_id.clone(),
877 is_error: tool_result.is_error,
878 content: tool_result.content.clone(),
879 })
880 .collect(),
881 context: message.context.clone(),
882 })
883 .collect(),
884 initial_project_snapshot,
885 cumulative_token_usage: this.cumulative_token_usage,
886 request_token_usage: this.request_token_usage.clone(),
887 detailed_summary_state: this.detailed_summary_state.clone(),
888 exceeded_window_error: this.exceeded_window_error.clone(),
889 })
890 })
891 }
892
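    /// Builds a completion request for this thread and streams it to `model`, attaching
    /// tool definitions only when the model reports tool support.
    ///
    /// Hedged sketch using the default configured model from the registry:
    ///
    /// ```ignore
    /// if let Some(ConfiguredModel { model, .. }) =
    ///     LanguageModelRegistry::read_global(cx).default_model()
    /// {
    ///     thread.update(cx, |thread, cx| {
    ///         thread.send_to_model(model, RequestKind::Chat, cx);
    ///     });
    /// }
    /// ```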
893 pub fn send_to_model(
894 &mut self,
895 model: Arc<dyn LanguageModel>,
896 request_kind: RequestKind,
897 cx: &mut Context<Self>,
898 ) {
899 let mut request = self.to_completion_request(request_kind, cx);
900 if model.supports_tools() {
901 request.tools = {
902 let mut tools = Vec::new();
903 tools.extend(
904 self.tools()
905 .read(cx)
906 .enabled_tools(cx)
907 .into_iter()
908 .filter_map(|tool| {
909 // Skip tools that cannot be supported
910 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
911 Some(LanguageModelRequestTool {
912 name: tool.name(),
913 description: tool.description(),
914 input_schema,
915 })
916 }),
917 );
918
919 tools
920 };
921 }
922
923 self.stream_completion(request, model, cx);
924 }
925
926 pub fn used_tools_since_last_user_message(&self) -> bool {
927 for message in self.messages.iter().rev() {
928 if self.tool_use.message_has_tool_results(message.id) {
929 return true;
930 } else if message.role == Role::User {
931 return false;
932 }
933 }
934
935 false
936 }
937
938 pub fn to_completion_request(
939 &self,
940 request_kind: RequestKind,
941 cx: &App,
942 ) -> LanguageModelRequest {
943 let mut request = LanguageModelRequest {
944 messages: vec![],
945 tools: Vec::new(),
946 stop: Vec::new(),
947 temperature: None,
948 };
949
950 if let Some(project_context) = self.project_context.borrow().as_ref() {
951 if let Some(system_prompt) = self
952 .prompt_builder
953 .generate_assistant_system_prompt(project_context)
954 .context("failed to generate assistant system prompt")
955 .log_err()
956 {
957 request.messages.push(LanguageModelRequestMessage {
958 role: Role::System,
959 content: vec![MessageContent::Text(system_prompt)],
960 cache: true,
961 });
962 }
963 } else {
964 log::error!("project_context not set.")
965 }
966
967 for message in &self.messages {
968 let mut request_message = LanguageModelRequestMessage {
969 role: message.role,
970 content: Vec::new(),
971 cache: false,
972 };
973
974 match request_kind {
975 RequestKind::Chat => {
976 self.tool_use
977 .attach_tool_results(message.id, &mut request_message);
978 }
979 RequestKind::Summarize => {
980 // We don't care about tool use during summarization.
981 if self.tool_use.message_has_tool_results(message.id) {
982 continue;
983 }
984 }
985 }
986
987 if !message.segments.is_empty() {
988 request_message
989 .content
990 .push(MessageContent::Text(message.to_string()));
991 }
992
993 match request_kind {
994 RequestKind::Chat => {
995 self.tool_use
996 .attach_tool_uses(message.id, &mut request_message);
997 }
998 RequestKind::Summarize => {
999 // We don't care about tool use during summarization.
1000 }
1001 };
1002
1003 request.messages.push(request_message);
1004 }
1005
1006 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1007 if let Some(last) = request.messages.last_mut() {
1008 last.cache = true;
1009 }
1010
1011 self.attached_tracked_files_state(&mut request.messages, cx);
1012
1013 request
1014 }
1015
1016 fn attached_tracked_files_state(
1017 &self,
1018 messages: &mut Vec<LanguageModelRequestMessage>,
1019 cx: &App,
1020 ) {
1021 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1022
1023 let mut stale_message = String::new();
1024
1025 let action_log = self.action_log.read(cx);
1026
1027 for stale_file in action_log.stale_buffers(cx) {
1028 let Some(file) = stale_file.read(cx).file() else {
1029 continue;
1030 };
1031
            if stale_message.is_empty() {
                writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
            }
1035
1036 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1037 }
1038
1039 let mut content = Vec::with_capacity(2);
1040
1041 if !stale_message.is_empty() {
1042 content.push(stale_message.into());
1043 }
1044
1045 if action_log.has_edited_files_since_project_diagnostics_check() {
1046 content.push(
1047 "\n\nWhen you're done making changes, make sure to check project diagnostics \
1048 and fix all errors AND warnings you introduced! \
1049 DO NOT mention you're going to do this until you're done."
1050 .into(),
1051 );
1052 }
1053
1054 if !content.is_empty() {
1055 let context_message = LanguageModelRequestMessage {
1056 role: Role::User,
1057 content,
1058 cache: false,
1059 };
1060
1061 messages.push(context_message);
1062 }
1063 }
1064
1065 pub fn stream_completion(
1066 &mut self,
1067 request: LanguageModelRequest,
1068 model: Arc<dyn LanguageModel>,
1069 cx: &mut Context<Self>,
1070 ) {
1071 let pending_completion_id = post_inc(&mut self.completion_count);
1072 let task = cx.spawn(async move |thread, cx| {
1073 let stream = model.stream_completion(request, &cx);
1074 let initial_token_usage =
1075 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1076 let stream_completion = async {
1077 let mut events = stream.await?;
1078 let mut stop_reason = StopReason::EndTurn;
1079 let mut current_token_usage = TokenUsage::default();
1080
1081 while let Some(event) = events.next().await {
1082 let event = event?;
1083
1084 thread.update(cx, |thread, cx| {
1085 match event {
1086 LanguageModelCompletionEvent::StartMessage { .. } => {
1087 thread.insert_message(
1088 Role::Assistant,
1089 vec![MessageSegment::Text(String::new())],
1090 cx,
1091 );
1092 }
1093 LanguageModelCompletionEvent::Stop(reason) => {
1094 stop_reason = reason;
1095 }
1096 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1097 thread.update_token_usage_at_last_message(token_usage);
1098 thread.cumulative_token_usage = thread.cumulative_token_usage
1099 + token_usage
1100 - current_token_usage;
1101 current_token_usage = token_usage;
1102 }
1103 LanguageModelCompletionEvent::Text(chunk) => {
1104 if let Some(last_message) = thread.messages.last_mut() {
1105 if last_message.role == Role::Assistant {
1106 last_message.push_text(&chunk);
1107 cx.emit(ThreadEvent::StreamedAssistantText(
1108 last_message.id,
1109 chunk,
1110 ));
1111 } else {
                                // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                // of a new Assistant response.
                                //
                                // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                // will result in duplicating the text of the chunk in the rendered Markdown.
1117 thread.insert_message(
1118 Role::Assistant,
1119 vec![MessageSegment::Text(chunk.to_string())],
1120 cx,
1121 );
1122 };
1123 }
1124 }
1125 LanguageModelCompletionEvent::Thinking(chunk) => {
1126 if let Some(last_message) = thread.messages.last_mut() {
1127 if last_message.role == Role::Assistant {
1128 last_message.push_thinking(&chunk);
1129 cx.emit(ThreadEvent::StreamedAssistantThinking(
1130 last_message.id,
1131 chunk,
1132 ));
1133 } else {
                                // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                // of a new Assistant response.
                                //
                                // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                // will result in duplicating the text of the chunk in the rendered Markdown.
1139 thread.insert_message(
1140 Role::Assistant,
1141 vec![MessageSegment::Thinking(chunk.to_string())],
1142 cx,
1143 );
1144 };
1145 }
1146 }
1147 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1148 let last_assistant_message_id = thread
1149 .messages
1150 .iter_mut()
1151 .rfind(|message| message.role == Role::Assistant)
1152 .map(|message| message.id)
1153 .unwrap_or_else(|| {
1154 thread.insert_message(Role::Assistant, vec![], cx)
1155 });
1156
1157 thread.tool_use.request_tool_use(
1158 last_assistant_message_id,
1159 tool_use,
1160 cx,
1161 );
1162 }
1163 }
1164
1165 thread.touch_updated_at();
1166 cx.emit(ThreadEvent::StreamedCompletion);
1167 cx.notify();
1168
1169 thread.auto_capture_telemetry(cx);
1170 })?;
1171
1172 smol::future::yield_now().await;
1173 }
1174
1175 thread.update(cx, |thread, cx| {
1176 thread
1177 .pending_completions
1178 .retain(|completion| completion.id != pending_completion_id);
1179
1180 if thread.summary.is_none() && thread.messages.len() >= 2 {
1181 thread.summarize(cx);
1182 }
1183 })?;
1184
1185 anyhow::Ok(stop_reason)
1186 };
1187
1188 let result = stream_completion.await;
1189
1190 thread
1191 .update(cx, |thread, cx| {
1192 thread.finalize_pending_checkpoint(cx);
1193 match result.as_ref() {
1194 Ok(stop_reason) => match stop_reason {
1195 StopReason::ToolUse => {
1196 let tool_uses = thread.use_pending_tools(cx);
1197 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1198 }
1199 StopReason::EndTurn => {}
1200 StopReason::MaxTokens => {}
1201 },
1202 Err(error) => {
1203 if error.is::<PaymentRequiredError>() {
1204 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1205 } else if error.is::<MaxMonthlySpendReachedError>() {
1206 cx.emit(ThreadEvent::ShowError(
1207 ThreadError::MaxMonthlySpendReached,
1208 ));
1209 } else if let Some(error) =
1210 error.downcast_ref::<ModelRequestLimitReachedError>()
1211 {
1212 cx.emit(ThreadEvent::ShowError(
1213 ThreadError::ModelRequestLimitReached { plan: error.plan },
1214 ));
1215 } else if let Some(known_error) =
1216 error.downcast_ref::<LanguageModelKnownError>()
1217 {
1218 match known_error {
1219 LanguageModelKnownError::ContextWindowLimitExceeded {
1220 tokens,
1221 } => {
1222 thread.exceeded_window_error = Some(ExceededWindowError {
1223 model_id: model.id(),
1224 token_count: *tokens,
1225 });
1226 cx.notify();
1227 }
1228 }
1229 } else {
1230 let error_message = error
1231 .chain()
1232 .map(|err| err.to_string())
1233 .collect::<Vec<_>>()
1234 .join("\n");
1235 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1236 header: "Error interacting with language model".into(),
1237 message: SharedString::from(error_message.clone()),
1238 }));
1239 }
1240
1241 thread.cancel_last_completion(cx);
1242 }
1243 }
1244 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1245
1246 thread.auto_capture_telemetry(cx);
1247
1248 if let Ok(initial_usage) = initial_token_usage {
1249 let usage = thread.cumulative_token_usage - initial_usage;
1250
1251 telemetry::event!(
1252 "Assistant Thread Completion",
1253 thread_id = thread.id().to_string(),
1254 model = model.telemetry_id(),
1255 model_provider = model.provider_id().to_string(),
1256 input_tokens = usage.input_tokens,
1257 output_tokens = usage.output_tokens,
1258 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1259 cache_read_input_tokens = usage.cache_read_input_tokens,
1260 );
1261 }
1262 })
1263 .ok();
1264 });
1265
1266 self.pending_completions.push(PendingCompletion {
1267 id: pending_completion_id,
1268 _task: task,
1269 });
1270 }
1271
1272 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1273 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1274 return;
1275 };
1276
1277 if !model.provider.is_authenticated(cx) {
1278 return;
1279 }
1280
1281 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1282 request.messages.push(LanguageModelRequestMessage {
1283 role: Role::User,
1284 content: vec![
1285 "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1286 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1287 If the conversation is about a specific subject, include it in the title. \
1288 Be descriptive. DO NOT speak in the first person."
1289 .into(),
1290 ],
1291 cache: false,
1292 });
1293
1294 self.pending_summary = cx.spawn(async move |this, cx| {
1295 async move {
1296 let stream = model.model.stream_completion_text(request, &cx);
1297 let mut messages = stream.await?;
1298
1299 let mut new_summary = String::new();
1300 while let Some(message) = messages.stream.next().await {
1301 let text = message?;
1302 let mut lines = text.lines();
1303 new_summary.extend(lines.next());
1304
1305 // Stop if the LLM generated multiple lines.
1306 if lines.next().is_some() {
1307 break;
1308 }
1309 }
1310
1311 this.update(cx, |this, cx| {
1312 if !new_summary.is_empty() {
1313 this.summary = Some(new_summary.into());
1314 }
1315
1316 cx.emit(ThreadEvent::SummaryGenerated);
1317 })?;
1318
1319 anyhow::Ok(())
1320 }
1321 .log_err()
1322 .await
1323 });
1324 }
1325
1326 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1327 let last_message_id = self.messages.last().map(|message| message.id)?;
1328
1329 match &self.detailed_summary_state {
1330 DetailedSummaryState::Generating { message_id, .. }
1331 | DetailedSummaryState::Generated { message_id, .. }
1332 if *message_id == last_message_id =>
1333 {
1334 // Already up-to-date
1335 return None;
1336 }
1337 _ => {}
1338 }
1339
1340 let ConfiguredModel { model, provider } =
1341 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1342
1343 if !provider.is_authenticated(cx) {
1344 return None;
1345 }
1346
1347 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1348
1349 request.messages.push(LanguageModelRequestMessage {
1350 role: Role::User,
1351 content: vec![
1352 "Generate a detailed summary of this conversation. Include:\n\
1353 1. A brief overview of what was discussed\n\
1354 2. Key facts or information discovered\n\
1355 3. Outcomes or conclusions reached\n\
1356 4. Any action items or next steps if any\n\
1357 Format it in Markdown with headings and bullet points."
1358 .into(),
1359 ],
1360 cache: false,
1361 });
1362
1363 let task = cx.spawn(async move |thread, cx| {
1364 let stream = model.stream_completion_text(request, &cx);
1365 let Some(mut messages) = stream.await.log_err() else {
1366 thread
1367 .update(cx, |this, _cx| {
1368 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1369 })
1370 .log_err();
1371
1372 return;
1373 };
1374
1375 let mut new_detailed_summary = String::new();
1376
1377 while let Some(chunk) = messages.stream.next().await {
1378 if let Some(chunk) = chunk.log_err() {
1379 new_detailed_summary.push_str(&chunk);
1380 }
1381 }
1382
1383 thread
1384 .update(cx, |this, _cx| {
1385 this.detailed_summary_state = DetailedSummaryState::Generated {
1386 text: new_detailed_summary.into(),
1387 message_id: last_message_id,
1388 };
1389 })
1390 .log_err();
1391 });
1392
1393 self.detailed_summary_state = DetailedSummaryState::Generating {
1394 message_id: last_message_id,
1395 };
1396
1397 Some(task)
1398 }
1399
1400 pub fn is_generating_detailed_summary(&self) -> bool {
1401 matches!(
1402 self.detailed_summary_state,
1403 DetailedSummaryState::Generating { .. }
1404 )
1405 }
1406
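    /// Queues every idle pending tool use: tools that require confirmation (and are not
    /// auto-allowed by settings) emit [`ThreadEvent::ToolConfirmationNeeded`]; the rest
    /// are run immediately.
    ///
    /// Sketch of a caller resolving a confirmation later (assumes a UI layer that picked
    /// a `tool_use` from `tools_needing_confirmation` and still holds the request
    /// `messages` and resolved `tool`; field types are assumed to match their use below):
    ///
    /// ```ignore
    /// thread.update(cx, |thread, cx| {
    ///     if user_approved {
    ///         thread.run_tool(
    ///             tool_use.id.clone(),
    ///             tool_use.ui_text.clone(),
    ///             tool_use.input.clone(),
    ///             &messages,
    ///             tool,
    ///             cx,
    ///         );
    ///     } else {
    ///         thread.deny_tool_use(tool_use.id.clone(), tool_use.name.clone(), cx);
    ///     }
    /// });
    /// ```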
1407 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1408 self.auto_capture_telemetry(cx);
1409 let request = self.to_completion_request(RequestKind::Chat, cx);
1410 let messages = Arc::new(request.messages);
1411 let pending_tool_uses = self
1412 .tool_use
1413 .pending_tool_uses()
1414 .into_iter()
1415 .filter(|tool_use| tool_use.status.is_idle())
1416 .cloned()
1417 .collect::<Vec<_>>();
1418
1419 for tool_use in pending_tool_uses.iter() {
1420 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1421 if tool.needs_confirmation(&tool_use.input, cx)
1422 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1423 {
1424 self.tool_use.confirm_tool_use(
1425 tool_use.id.clone(),
1426 tool_use.ui_text.clone(),
1427 tool_use.input.clone(),
1428 messages.clone(),
1429 tool,
1430 );
1431 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1432 } else {
1433 self.run_tool(
1434 tool_use.id.clone(),
1435 tool_use.ui_text.clone(),
1436 tool_use.input.clone(),
1437 &messages,
1438 tool,
1439 cx,
1440 );
1441 }
1442 }
1443 }
1444
1445 pending_tool_uses
1446 }
1447
1448 pub fn run_tool(
1449 &mut self,
1450 tool_use_id: LanguageModelToolUseId,
1451 ui_text: impl Into<SharedString>,
1452 input: serde_json::Value,
1453 messages: &[LanguageModelRequestMessage],
1454 tool: Arc<dyn Tool>,
1455 cx: &mut Context<Thread>,
1456 ) {
1457 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1458 self.tool_use
1459 .run_pending_tool(tool_use_id, ui_text.into(), task);
1460 }
1461
1462 fn spawn_tool_use(
1463 &mut self,
1464 tool_use_id: LanguageModelToolUseId,
1465 messages: &[LanguageModelRequestMessage],
1466 input: serde_json::Value,
1467 tool: Arc<dyn Tool>,
1468 cx: &mut Context<Thread>,
1469 ) -> Task<()> {
1470 let tool_name: Arc<str> = tool.name().into();
1471
1472 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1473 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1474 } else {
1475 tool.run(
1476 input,
1477 messages,
1478 self.project.clone(),
1479 self.action_log.clone(),
1480 cx,
1481 )
1482 };
1483
1484 // Store the card separately if it exists
1485 if let Some(card) = tool_result.card.clone() {
1486 self.tool_use
1487 .insert_tool_result_card(tool_use_id.clone(), card);
1488 }
1489
1490 cx.spawn({
1491 async move |thread: WeakEntity<Thread>, cx| {
1492 let output = tool_result.output.await;
1493
1494 thread
1495 .update(cx, |thread, cx| {
1496 let pending_tool_use = thread.tool_use.insert_tool_output(
1497 tool_use_id.clone(),
1498 tool_name,
1499 output,
1500 cx,
1501 );
1502 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1503 })
1504 .ok();
1505 }
1506 })
1507 }
1508
1509 fn tool_finished(
1510 &mut self,
1511 tool_use_id: LanguageModelToolUseId,
1512 pending_tool_use: Option<PendingToolUse>,
1513 canceled: bool,
1514 cx: &mut Context<Self>,
1515 ) {
1516 if self.all_tools_finished() {
1517 let model_registry = LanguageModelRegistry::read_global(cx);
1518 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1519 self.attach_tool_results(cx);
1520 if !canceled {
1521 self.send_to_model(model, RequestKind::Chat, cx);
1522 }
1523 }
1524 }
1525
1526 cx.emit(ThreadEvent::ToolFinished {
1527 tool_use_id,
1528 pending_tool_use,
1529 });
1530 }
1531
1532 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1533 // Insert a user message to contain the tool results.
1534 self.insert_user_message(
1535 // TODO: Sending up a user message without any content results in the model sending back
1536 // responses that also don't have any content. We currently don't handle this case well,
1537 // so for now we provide some text to keep the model on track.
1538 "Here are the tool results.",
1539 Vec::new(),
1540 None,
1541 cx,
1542 );
1543 }
1544
1545 /// Cancels the last pending completion, if there are any pending.
1546 ///
1547 /// Returns whether a completion was canceled.
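    ///
    /// Typical use is wiring up a "stop generating" action (illustrative sketch):
    ///
    /// ```ignore
    /// thread.update(cx, |thread, cx| {
    ///     let _was_canceled = thread.cancel_last_completion(cx);
    /// });
    /// ```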
1548 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1549 let canceled = if self.pending_completions.pop().is_some() {
1550 true
1551 } else {
1552 let mut canceled = false;
1553 for pending_tool_use in self.tool_use.cancel_pending() {
1554 canceled = true;
1555 self.tool_finished(
1556 pending_tool_use.id.clone(),
1557 Some(pending_tool_use),
1558 true,
1559 cx,
1560 );
1561 }
1562 canceled
1563 };
1564 self.finalize_pending_checkpoint(cx);
1565 canceled
1566 }
1567
1568 pub fn feedback(&self) -> Option<ThreadFeedback> {
1569 self.feedback
1570 }
1571
1572 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1573 self.message_feedback.get(&message_id).copied()
1574 }
1575
1576 pub fn report_message_feedback(
1577 &mut self,
1578 message_id: MessageId,
1579 feedback: ThreadFeedback,
1580 cx: &mut Context<Self>,
1581 ) -> Task<Result<()>> {
1582 if self.message_feedback.get(&message_id) == Some(&feedback) {
1583 return Task::ready(Ok(()));
1584 }
1585
1586 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1587 let serialized_thread = self.serialize(cx);
1588 let thread_id = self.id().clone();
1589 let client = self.project.read(cx).client();
1590
1591 let enabled_tool_names: Vec<String> = self
1592 .tools()
1593 .read(cx)
1594 .enabled_tools(cx)
1595 .iter()
1596 .map(|tool| tool.name().to_string())
1597 .collect();
1598
1599 self.message_feedback.insert(message_id, feedback);
1600
1601 cx.notify();
1602
1603 let message_content = self
1604 .message(message_id)
1605 .map(|msg| msg.to_string())
1606 .unwrap_or_default();
1607
1608 cx.background_spawn(async move {
1609 let final_project_snapshot = final_project_snapshot.await;
1610 let serialized_thread = serialized_thread.await?;
1611 let thread_data =
1612 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1613
1614 let rating = match feedback {
1615 ThreadFeedback::Positive => "positive",
1616 ThreadFeedback::Negative => "negative",
1617 };
1618 telemetry::event!(
1619 "Assistant Thread Rated",
1620 rating,
1621 thread_id,
1622 enabled_tool_names,
1623 message_id = message_id.0,
1624 message_content,
1625 thread_data,
1626 final_project_snapshot
1627 );
1628 client.telemetry().flush_events();
1629
1630 Ok(())
1631 })
1632 }
1633
1634 pub fn report_feedback(
1635 &mut self,
1636 feedback: ThreadFeedback,
1637 cx: &mut Context<Self>,
1638 ) -> Task<Result<()>> {
1639 let last_assistant_message_id = self
1640 .messages
1641 .iter()
1642 .rev()
1643 .find(|msg| msg.role == Role::Assistant)
1644 .map(|msg| msg.id);
1645
1646 if let Some(message_id) = last_assistant_message_id {
1647 self.report_message_feedback(message_id, feedback, cx)
1648 } else {
1649 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1650 let serialized_thread = self.serialize(cx);
1651 let thread_id = self.id().clone();
1652 let client = self.project.read(cx).client();
1653 self.feedback = Some(feedback);
1654 cx.notify();
1655
1656 cx.background_spawn(async move {
1657 let final_project_snapshot = final_project_snapshot.await;
1658 let serialized_thread = serialized_thread.await?;
1659 let thread_data = serde_json::to_value(serialized_thread)
1660 .unwrap_or_else(|_| serde_json::Value::Null);
1661
1662 let rating = match feedback {
1663 ThreadFeedback::Positive => "positive",
1664 ThreadFeedback::Negative => "negative",
1665 };
1666 telemetry::event!(
1667 "Assistant Thread Rated",
1668 rating,
1669 thread_id,
1670 thread_data,
1671 final_project_snapshot
1672 );
1673 client.telemetry().flush_events();
1674
1675 Ok(())
1676 })
1677 }
1678 }
1679
1680 /// Create a snapshot of the current project state including git information and unsaved buffers.
1681 fn project_snapshot(
1682 project: Entity<Project>,
1683 cx: &mut Context<Self>,
1684 ) -> Task<Arc<ProjectSnapshot>> {
1685 let git_store = project.read(cx).git_store().clone();
1686 let worktree_snapshots: Vec<_> = project
1687 .read(cx)
1688 .visible_worktrees(cx)
1689 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1690 .collect();
1691
1692 cx.spawn(async move |_, cx| {
1693 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1694
1695 let mut unsaved_buffers = Vec::new();
1696 cx.update(|app_cx| {
1697 let buffer_store = project.read(app_cx).buffer_store();
1698 for buffer_handle in buffer_store.read(app_cx).buffers() {
1699 let buffer = buffer_handle.read(app_cx);
1700 if buffer.is_dirty() {
1701 if let Some(file) = buffer.file() {
1702 let path = file.path().to_string_lossy().to_string();
1703 unsaved_buffers.push(path);
1704 }
1705 }
1706 }
1707 })
1708 .ok();
1709
1710 Arc::new(ProjectSnapshot {
1711 worktree_snapshots,
1712 unsaved_buffer_paths: unsaved_buffers,
1713 timestamp: Utc::now(),
1714 })
1715 })
1716 }
1717
1718 fn worktree_snapshot(
1719 worktree: Entity<project::Worktree>,
1720 git_store: Entity<GitStore>,
1721 cx: &App,
1722 ) -> Task<WorktreeSnapshot> {
1723 cx.spawn(async move |cx| {
1724 // Get worktree path and snapshot
1725 let worktree_info = cx.update(|app_cx| {
1726 let worktree = worktree.read(app_cx);
1727 let path = worktree.abs_path().to_string_lossy().to_string();
1728 let snapshot = worktree.snapshot();
1729 (path, snapshot)
1730 });
1731
1732 let Ok((worktree_path, _snapshot)) = worktree_info else {
1733 return WorktreeSnapshot {
1734 worktree_path: String::new(),
1735 git_state: None,
1736 };
1737 };
1738
1739 let git_state = git_store
1740 .update(cx, |git_store, cx| {
1741 git_store
1742 .repositories()
1743 .values()
1744 .find(|repo| {
1745 repo.read(cx)
1746 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1747 .is_some()
1748 })
1749 .cloned()
1750 })
1751 .ok()
1752 .flatten()
1753 .map(|repo| {
1754 repo.update(cx, |repo, _| {
1755 let current_branch =
1756 repo.branch.as_ref().map(|branch| branch.name.to_string());
1757 repo.send_job(None, |state, _| async move {
1758 let RepositoryState::Local { backend, .. } = state else {
1759 return GitState {
1760 remote_url: None,
1761 head_sha: None,
1762 current_branch,
1763 diff: None,
1764 };
1765 };
1766
1767 let remote_url = backend.remote_url("origin");
1768 let head_sha = backend.head_sha();
1769 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1770
1771 GitState {
1772 remote_url,
1773 head_sha,
1774 current_branch,
1775 diff,
1776 }
1777 })
1778 })
1779 });
1780
1781 let git_state = match git_state {
1782 Some(git_state) => match git_state.ok() {
1783 Some(git_state) => git_state.await.ok(),
1784 None => None,
1785 },
1786 None => None,
1787 };
1788
1789 WorktreeSnapshot {
1790 worktree_path,
1791 git_state,
1792 }
1793 })
1794 }
1795
1796 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1797 let mut markdown = Vec::new();
1798
1799 if let Some(summary) = self.summary() {
1800 writeln!(markdown, "# {summary}\n")?;
1801 };
1802
1803 for message in self.messages() {
1804 writeln!(
1805 markdown,
1806 "## {role}\n",
1807 role = match message.role {
1808 Role::User => "User",
1809 Role::Assistant => "Assistant",
1810 Role::System => "System",
1811 }
1812 )?;
1813
1814 if !message.context.is_empty() {
1815 writeln!(markdown, "{}", message.context)?;
1816 }
1817
1818 for segment in &message.segments {
1819 match segment {
1820 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1821 MessageSegment::Thinking(text) => {
1822 writeln!(markdown, "<think>{}</think>\n", text)?
1823 }
1824 }
1825 }
1826
1827 for tool_use in self.tool_uses_for_message(message.id, cx) {
1828 writeln!(
1829 markdown,
1830 "**Use Tool: {} ({})**",
1831 tool_use.name, tool_use.id
1832 )?;
1833 writeln!(markdown, "```json")?;
1834 writeln!(
1835 markdown,
1836 "{}",
1837 serde_json::to_string_pretty(&tool_use.input)?
1838 )?;
1839 writeln!(markdown, "```")?;
1840 }
1841
1842 for tool_result in self.tool_results_for_message(message.id) {
1843 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1844 if tool_result.is_error {
1845 write!(markdown, " (Error)")?;
1846 }
1847
1848 writeln!(markdown, "**\n")?;
1849 writeln!(markdown, "{}", tool_result.content)?;
1850 }
1851 }
1852
1853 Ok(String::from_utf8_lossy(&markdown).to_string())
1854 }
1855
1856 pub fn keep_edits_in_range(
1857 &mut self,
1858 buffer: Entity<language::Buffer>,
1859 buffer_range: Range<language::Anchor>,
1860 cx: &mut Context<Self>,
1861 ) {
1862 self.action_log.update(cx, |action_log, cx| {
1863 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1864 });
1865 }
1866
1867 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1868 self.action_log
1869 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1870 }
1871
1872 pub fn reject_edits_in_ranges(
1873 &mut self,
1874 buffer: Entity<language::Buffer>,
1875 buffer_ranges: Vec<Range<language::Anchor>>,
1876 cx: &mut Context<Self>,
1877 ) -> Task<Result<()>> {
1878 self.action_log.update(cx, |action_log, cx| {
1879 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1880 })
1881 }
1882
1883 pub fn action_log(&self) -> &Entity<ActionLog> {
1884 &self.action_log
1885 }
1886
1887 pub fn project(&self) -> &Entity<Project> {
1888 &self.project
1889 }
1890
1891 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1892 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1893 return;
1894 }
1895
1896 let now = Instant::now();
1897 if let Some(last) = self.last_auto_capture_at {
1898 if now.duration_since(last).as_secs() < 10 {
1899 return;
1900 }
1901 }
1902
1903 self.last_auto_capture_at = Some(now);
1904
1905 let thread_id = self.id().clone();
1906 let github_login = self
1907 .project
1908 .read(cx)
1909 .user_store()
1910 .read(cx)
1911 .current_user()
1912 .map(|user| user.github_login.clone());
1913 let client = self.project.read(cx).client().clone();
1914 let serialize_task = self.serialize(cx);
1915
1916 cx.background_executor()
1917 .spawn(async move {
1918 if let Ok(serialized_thread) = serialize_task.await {
1919 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1920 telemetry::event!(
1921 "Agent Thread Auto-Captured",
1922 thread_id = thread_id.to_string(),
1923 thread_data = thread_data,
1924 auto_capture_reason = "tracked_user",
1925 github_login = github_login
1926 );
1927
1928 client.telemetry().flush_events();
1929 }
1930 }
1931 })
1932 .detach();
1933 }
1934
1935 pub fn cumulative_token_usage(&self) -> TokenUsage {
1936 self.cumulative_token_usage
1937 }
1938
1939 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
1940 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
1941 return TotalTokenUsage::default();
1942 };
1943
1944 let max = model.model.max_token_count();
1945
1946 let index = self
1947 .messages
1948 .iter()
1949 .position(|msg| msg.id == message_id)
1950 .unwrap_or(0);
1951
1952 if index == 0 {
1953 return TotalTokenUsage { total: 0, max };
1954 }
1955
1956 let token_usage = &self
1957 .request_token_usage
1958 .get(index - 1)
1959 .cloned()
1960 .unwrap_or_default();
1961
1962 TotalTokenUsage {
1963 total: token_usage.total_tokens() as usize,
1964 max,
1965 }
1966 }
1967
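    /// Total token usage for this thread measured against the default model's context
    /// window. If the last request exceeded the window for that model, the recorded
    /// overflow count is reported instead of the streamed usage.
    ///
    /// Illustrative check (sketch; assumes a configured default model):
    ///
    /// ```ignore
    /// let usage = thread.total_token_usage(cx);
    /// if usage.ratio() == TokenUsageRatio::Exceeded {
    ///     // The conversation no longer fits in the model's context window.
    /// }
    /// ```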
1968 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1969 let model_registry = LanguageModelRegistry::read_global(cx);
1970 let Some(model) = model_registry.default_model() else {
1971 return TotalTokenUsage::default();
1972 };
1973
1974 let max = model.model.max_token_count();
1975
1976 if let Some(exceeded_error) = &self.exceeded_window_error {
1977 if model.model.id() == exceeded_error.model_id {
1978 return TotalTokenUsage {
1979 total: exceeded_error.token_count,
1980 max,
1981 };
1982 }
1983 }
1984
1985 let total = self
1986 .token_usage_at_last_message()
1987 .unwrap_or_default()
1988 .total_tokens() as usize;
1989
1990 TotalTokenUsage { total, max }
1991 }
1992
1993 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
1994 self.request_token_usage
1995 .get(self.messages.len().saturating_sub(1))
1996 .or_else(|| self.request_token_usage.last())
1997 .cloned()
1998 }
1999
2000 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2001 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2002 self.request_token_usage
2003 .resize(self.messages.len(), placeholder);
2004
2005 if let Some(last) = self.request_token_usage.last_mut() {
2006 *last = token_usage;
2007 }
2008 }
2009
    pub fn deny_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        cx: &mut Context<Self>,
    ) {
        let err = Err(anyhow::anyhow!(
            "Permission to run tool action denied by user"
        ));

        self.tool_use
            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
        self.tool_finished(tool_use_id, None, true, cx);
    }
}

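/// An error that occurred while running a [`Thread`].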
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    #[error("Payment required")]
    PaymentRequired,
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}

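/// An event emitted by a [`Thread`].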
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    ShowError(ThreadError),
    StreamedCompletion,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}

impl EventEmitter<ThreadEvent> for Thread {}

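/// A completion request that is currently in flight, along with the task driving it.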
struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // The path separator in the rendered context differs by platform
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
 println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain all 3 user messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert a new line so the buffer no longer matches what was last read into context
            buffer.edit(
                [(1..1, "\n println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

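    // Registers the settings and globals that these tests depend on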
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project containing the given file tree
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

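    // Helper to create the workspace, thread store, thread, and context store used by the tests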
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await;

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

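    // Helper to open the file at `path` in the project and add it to the context store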
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}