1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
23};
24use project::Project;
25use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
26use prompt_store::PromptBuilder;
27use proto::Plan;
28use schemars::JsonSchema;
29use serde::{Deserialize, Serialize};
30use settings::Settings;
31use thiserror::Error;
32use util::{ResultExt as _, TryFutureExt as _, post_inc};
33use uuid::Uuid;
34
35use crate::context::{AssistantContext, ContextId, format_context_as_string};
36use crate::thread_store::{
37 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
38 SerializedToolUse, SharedProjectContext,
39};
40use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
41
42#[derive(Debug, Clone, Copy)]
43pub enum RequestKind {
44 Chat,
45 /// Used when summarizing a thread.
46 Summarize,
47}
48
49#[derive(
50 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
51)]
52pub struct ThreadId(Arc<str>);
53
54impl ThreadId {
55 pub fn new() -> Self {
56 Self(Uuid::new_v4().to_string().into())
57 }
58}
59
60impl std::fmt::Display for ThreadId {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(f, "{}", self.0)
63 }
64}
65
66impl From<&str> for ThreadId {
67 fn from(value: &str) -> Self {
68 Self(value.into())
69 }
70}
71
72#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
73pub struct MessageId(pub(crate) usize);
74
75impl MessageId {
76 fn post_inc(&mut self) -> Self {
77 Self(post_inc(&mut self.0))
78 }
79}
80
81/// A message in a [`Thread`].
82#[derive(Debug, Clone)]
83pub struct Message {
84 pub id: MessageId,
85 pub role: Role,
86 pub segments: Vec<MessageSegment>,
87 pub context: String,
88}
89
90impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    /// The model sometimes runs tools without producing any text, or produces only a
    /// marker ([`USING_TOOL_MARKER`]), and such messages should not be shown.
    pub fn should_display_content(&self) -> bool {
        self.segments.iter().all(|segment| segment.should_display())
    }
96
97 pub fn push_thinking(&mut self, text: &str) {
98 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
99 segment.push_str(text);
100 } else {
101 self.segments
102 .push(MessageSegment::Thinking(text.to_string()));
103 }
104 }
105
106 pub fn push_text(&mut self, text: &str) {
107 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
108 segment.push_str(text);
109 } else {
110 self.segments.push(MessageSegment::Text(text.to_string()));
111 }
112 }
113
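    /// Renders the message as plain text: any attached context first, then each
    /// segment in order, with thinking segments wrapped in `<think>` tags.
    ///
    /// A minimal sketch of how segments accumulate (illustrative only; `Message`
    /// values are normally constructed by [`Thread`]):
    ///
    /// ```ignore
    /// let mut message = Message {
    ///     id: MessageId(0),
    ///     role: Role::Assistant,
    ///     segments: Vec::new(),
    ///     context: String::new(),
    /// };
    /// message.push_thinking("planning the edit");
    /// message.push_text("Here is the change.");
    /// assert_eq!(
    ///     message.to_string(),
    ///     "<think>planning the edit</think>Here is the change."
    /// );
    /// ```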
114 pub fn to_string(&self) -> String {
115 let mut result = String::new();
116
117 if !self.context.is_empty() {
118 result.push_str(&self.context);
119 }
120
121 for segment in &self.segments {
122 match segment {
123 MessageSegment::Text(text) => result.push_str(text),
124 MessageSegment::Thinking(text) => {
125 result.push_str("<think>");
126 result.push_str(text);
127 result.push_str("</think>");
128 }
129 }
130 }
131
132 result
133 }
134}
135
136#[derive(Debug, Clone, PartialEq, Eq)]
137pub enum MessageSegment {
138 Text(String),
139 Thinking(String),
140}
141
142impl MessageSegment {
143 pub fn text_mut(&mut self) -> &mut String {
144 match self {
145 Self::Text(text) => text,
146 Self::Thinking(text) => text,
147 }
148 }
149
    pub fn should_display(&self) -> bool {
        // We add USING_TOOL_MARKER when making a request that includes tool uses
        // without non-whitespace text around them, and this can cause the model
        // to mimic the pattern, so we consider those segments not displayable.
        match self {
            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::Thinking(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
        }
    }
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct ProjectSnapshot {
163 pub worktree_snapshots: Vec<WorktreeSnapshot>,
164 pub unsaved_buffer_paths: Vec<String>,
165 pub timestamp: DateTime<Utc>,
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct WorktreeSnapshot {
170 pub worktree_path: String,
171 pub git_state: Option<GitState>,
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct GitState {
176 pub remote_url: Option<String>,
177 pub head_sha: Option<String>,
178 pub current_branch: Option<String>,
179 pub diff: Option<String>,
180}
181
182#[derive(Clone)]
183pub struct ThreadCheckpoint {
184 message_id: MessageId,
185 git_checkpoint: GitStoreCheckpoint,
186}
187
188#[derive(Copy, Clone, Debug, PartialEq, Eq)]
189pub enum ThreadFeedback {
190 Positive,
191 Negative,
192}
193
194pub enum LastRestoreCheckpoint {
195 Pending {
196 message_id: MessageId,
197 },
198 Error {
199 message_id: MessageId,
200 error: String,
201 },
202}
203
204impl LastRestoreCheckpoint {
205 pub fn message_id(&self) -> MessageId {
206 match self {
207 LastRestoreCheckpoint::Pending { message_id } => *message_id,
208 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
209 }
210 }
211}
212
213#[derive(Clone, Debug, Default, Serialize, Deserialize)]
214pub enum DetailedSummaryState {
215 #[default]
216 NotGenerated,
217 Generating {
218 message_id: MessageId,
219 },
220 Generated {
221 text: SharedString,
222 message_id: MessageId,
223 },
224}
225
226#[derive(Default)]
227pub struct TotalTokenUsage {
228 pub total: usize,
229 pub max: usize,
230}
231
232impl TotalTokenUsage {
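    /// Classifies the current usage against the model's context window. The warning
    /// threshold defaults to 80% of the window; in debug builds it can be overridden
    /// via the `ZED_THREAD_WARNING_THRESHOLD` environment variable.
    ///
    /// Illustrative sketch (assumes the default threshold):
    ///
    /// ```ignore
    /// let usage = TotalTokenUsage { total: 90, max: 100 };
    /// assert_eq!(usage.ratio(), TokenUsageRatio::Warning);
    /// ```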
233 pub fn ratio(&self) -> TokenUsageRatio {
234 #[cfg(debug_assertions)]
235 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
236 .unwrap_or("0.8".to_string())
237 .parse()
238 .unwrap();
239 #[cfg(not(debug_assertions))]
240 let warning_threshold: f32 = 0.8;
241
242 if self.total >= self.max {
243 TokenUsageRatio::Exceeded
244 } else if self.total as f32 / self.max as f32 >= warning_threshold {
245 TokenUsageRatio::Warning
246 } else {
247 TokenUsageRatio::Normal
248 }
249 }
250
251 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
252 TotalTokenUsage {
253 total: self.total + tokens,
254 max: self.max,
255 }
256 }
257}
258
259#[derive(Debug, Default, PartialEq, Eq)]
260pub enum TokenUsageRatio {
261 #[default]
262 Normal,
263 Warning,
264 Exceeded,
265}
266
267/// A thread of conversation with the LLM.
268pub struct Thread {
269 id: ThreadId,
270 updated_at: DateTime<Utc>,
271 summary: Option<SharedString>,
272 pending_summary: Task<Option<()>>,
273 detailed_summary_state: DetailedSummaryState,
274 messages: Vec<Message>,
275 next_message_id: MessageId,
276 context: BTreeMap<ContextId, AssistantContext>,
277 context_by_message: HashMap<MessageId, Vec<ContextId>>,
278 project_context: SharedProjectContext,
279 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
280 completion_count: usize,
281 pending_completions: Vec<PendingCompletion>,
282 project: Entity<Project>,
283 prompt_builder: Arc<PromptBuilder>,
284 tools: Entity<ToolWorkingSet>,
285 tool_use: ToolUseState,
286 action_log: Entity<ActionLog>,
287 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
288 pending_checkpoint: Option<ThreadCheckpoint>,
289 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
290 request_token_usage: Vec<TokenUsage>,
291 cumulative_token_usage: TokenUsage,
292 exceeded_window_error: Option<ExceededWindowError>,
293 feedback: Option<ThreadFeedback>,
294 message_feedback: HashMap<MessageId, ThreadFeedback>,
295 last_auto_capture_at: Option<Instant>,
296}
297
298#[derive(Debug, Clone, Serialize, Deserialize)]
299pub struct ExceededWindowError {
300 /// Model used when last message exceeded context window
301 model_id: LanguageModelId,
302 /// Token count including last message
303 token_count: usize,
304}
305
306impl Thread {
307 pub fn new(
308 project: Entity<Project>,
309 tools: Entity<ToolWorkingSet>,
310 prompt_builder: Arc<PromptBuilder>,
311 system_prompt: SharedProjectContext,
312 cx: &mut Context<Self>,
313 ) -> Self {
314 Self {
315 id: ThreadId::new(),
316 updated_at: Utc::now(),
317 summary: None,
318 pending_summary: Task::ready(None),
319 detailed_summary_state: DetailedSummaryState::NotGenerated,
320 messages: Vec::new(),
321 next_message_id: MessageId(0),
322 context: BTreeMap::default(),
323 context_by_message: HashMap::default(),
324 project_context: system_prompt,
325 checkpoints_by_message: HashMap::default(),
326 completion_count: 0,
327 pending_completions: Vec::new(),
328 project: project.clone(),
329 prompt_builder,
330 tools: tools.clone(),
331 last_restore_checkpoint: None,
332 pending_checkpoint: None,
333 tool_use: ToolUseState::new(tools.clone()),
334 action_log: cx.new(|_| ActionLog::new(project.clone())),
335 initial_project_snapshot: {
336 let project_snapshot = Self::project_snapshot(project, cx);
337 cx.foreground_executor()
338 .spawn(async move { Some(project_snapshot.await) })
339 .shared()
340 },
341 request_token_usage: Vec::new(),
342 cumulative_token_usage: TokenUsage::default(),
343 exceeded_window_error: None,
344 feedback: None,
345 message_feedback: HashMap::default(),
346 last_auto_capture_at: None,
347 }
348 }
349
350 pub fn deserialize(
351 id: ThreadId,
352 serialized: SerializedThread,
353 project: Entity<Project>,
354 tools: Entity<ToolWorkingSet>,
355 prompt_builder: Arc<PromptBuilder>,
356 project_context: SharedProjectContext,
357 cx: &mut Context<Self>,
358 ) -> Self {
359 let next_message_id = MessageId(
360 serialized
361 .messages
362 .last()
363 .map(|message| message.id.0 + 1)
364 .unwrap_or(0),
365 );
366 let tool_use =
367 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
368
369 Self {
370 id,
371 updated_at: serialized.updated_at,
372 summary: Some(serialized.summary),
373 pending_summary: Task::ready(None),
374 detailed_summary_state: serialized.detailed_summary_state,
375 messages: serialized
376 .messages
377 .into_iter()
378 .map(|message| Message {
379 id: message.id,
380 role: message.role,
381 segments: message
382 .segments
383 .into_iter()
384 .map(|segment| match segment {
385 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
386 SerializedMessageSegment::Thinking { text } => {
387 MessageSegment::Thinking(text)
388 }
389 })
390 .collect(),
391 context: message.context,
392 })
393 .collect(),
394 next_message_id,
395 context: BTreeMap::default(),
396 context_by_message: HashMap::default(),
397 project_context,
398 checkpoints_by_message: HashMap::default(),
399 completion_count: 0,
400 pending_completions: Vec::new(),
401 last_restore_checkpoint: None,
402 pending_checkpoint: None,
403 project: project.clone(),
404 prompt_builder,
405 tools,
406 tool_use,
407 action_log: cx.new(|_| ActionLog::new(project)),
408 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
409 request_token_usage: serialized.request_token_usage,
410 cumulative_token_usage: serialized.cumulative_token_usage,
411 exceeded_window_error: None,
412 feedback: None,
413 message_feedback: HashMap::default(),
414 last_auto_capture_at: None,
415 }
416 }
417
418 pub fn id(&self) -> &ThreadId {
419 &self.id
420 }
421
422 pub fn is_empty(&self) -> bool {
423 self.messages.is_empty()
424 }
425
426 pub fn updated_at(&self) -> DateTime<Utc> {
427 self.updated_at
428 }
429
430 pub fn touch_updated_at(&mut self) {
431 self.updated_at = Utc::now();
432 }
433
434 pub fn summary(&self) -> Option<SharedString> {
435 self.summary.clone()
436 }
437
438 pub fn project_context(&self) -> SharedProjectContext {
439 self.project_context.clone()
440 }
441
442 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
443
444 pub fn summary_or_default(&self) -> SharedString {
445 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
446 }
447
448 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
449 let Some(current_summary) = &self.summary else {
450 // Don't allow setting summary until generated
451 return;
452 };
453
454 let mut new_summary = new_summary.into();
455
456 if new_summary.is_empty() {
457 new_summary = Self::DEFAULT_SUMMARY;
458 }
459
460 if current_summary != &new_summary {
461 self.summary = Some(new_summary);
462 cx.emit(ThreadEvent::SummaryChanged);
463 }
464 }
465
466 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
467 self.latest_detailed_summary()
468 .unwrap_or_else(|| self.text().into())
469 }
470
471 fn latest_detailed_summary(&self) -> Option<SharedString> {
472 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
473 Some(text.clone())
474 } else {
475 None
476 }
477 }
478
479 pub fn message(&self, id: MessageId) -> Option<&Message> {
480 self.messages.iter().find(|message| message.id == id)
481 }
482
483 pub fn messages(&self) -> impl Iterator<Item = &Message> {
484 self.messages.iter()
485 }
486
487 pub fn is_generating(&self) -> bool {
488 !self.pending_completions.is_empty() || !self.all_tools_finished()
489 }
490
491 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
492 &self.tools
493 }
494
495 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
496 self.tool_use
497 .pending_tool_uses()
498 .into_iter()
499 .find(|tool_use| &tool_use.id == id)
500 }
501
502 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
503 self.tool_use
504 .pending_tool_uses()
505 .into_iter()
506 .filter(|tool_use| tool_use.status.needs_confirmation())
507 }
508
509 pub fn has_pending_tool_uses(&self) -> bool {
510 !self.tool_use.pending_tool_uses().is_empty()
511 }
512
513 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
514 self.checkpoints_by_message.get(&id).cloned()
515 }
516
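    /// Restores the project to the given checkpoint's git state and, on success,
    /// truncates the thread back to the checkpoint's message. The restore status is
    /// surfaced through [`LastRestoreCheckpoint`] while the operation is in flight.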
517 pub fn restore_checkpoint(
518 &mut self,
519 checkpoint: ThreadCheckpoint,
520 cx: &mut Context<Self>,
521 ) -> Task<Result<()>> {
522 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
523 message_id: checkpoint.message_id,
524 });
525 cx.emit(ThreadEvent::CheckpointChanged);
526 cx.notify();
527
528 let git_store = self.project().read(cx).git_store().clone();
529 let restore = git_store.update(cx, |git_store, cx| {
530 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
531 });
532
533 cx.spawn(async move |this, cx| {
534 let result = restore.await;
535 this.update(cx, |this, cx| {
536 if let Err(err) = result.as_ref() {
537 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
538 message_id: checkpoint.message_id,
539 error: err.to_string(),
540 });
541 } else {
542 this.truncate(checkpoint.message_id, cx);
543 this.last_restore_checkpoint = None;
544 }
545 this.pending_checkpoint = None;
546 cx.emit(ThreadEvent::CheckpointChanged);
547 cx.notify();
548 })?;
549 result
550 })
551 }
552
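    /// Once generation has finished, compares the checkpoint taken before the request
    /// with the project's current state: if nothing changed, the pending checkpoint is
    /// discarded; otherwise it is recorded against its message so it can be restored.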
553 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
554 let pending_checkpoint = if self.is_generating() {
555 return;
556 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
557 checkpoint
558 } else {
559 return;
560 };
561
562 let git_store = self.project.read(cx).git_store().clone();
563 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
564 cx.spawn(async move |this, cx| match final_checkpoint.await {
565 Ok(final_checkpoint) => {
566 let equal = git_store
567 .update(cx, |store, cx| {
568 store.compare_checkpoints(
569 pending_checkpoint.git_checkpoint.clone(),
570 final_checkpoint.clone(),
571 cx,
572 )
573 })?
574 .await
575 .unwrap_or(false);
576
577 if equal {
578 git_store
579 .update(cx, |store, cx| {
580 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
581 })?
582 .detach();
583 } else {
584 this.update(cx, |this, cx| {
585 this.insert_checkpoint(pending_checkpoint, cx)
586 })?;
587 }
588
589 git_store
590 .update(cx, |store, cx| {
591 store.delete_checkpoint(final_checkpoint, cx)
592 })?
593 .detach();
594
595 Ok(())
596 }
597 Err(_) => this.update(cx, |this, cx| {
598 this.insert_checkpoint(pending_checkpoint, cx)
599 }),
600 })
601 .detach();
602 }
603
604 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
605 self.checkpoints_by_message
606 .insert(checkpoint.message_id, checkpoint);
607 cx.emit(ThreadEvent::CheckpointChanged);
608 cx.notify();
609 }
610
611 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
612 self.last_restore_checkpoint.as_ref()
613 }
614
615 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
616 let Some(message_ix) = self
617 .messages
618 .iter()
619 .rposition(|message| message.id == message_id)
620 else {
621 return;
622 };
623 for deleted_message in self.messages.drain(message_ix..) {
624 self.context_by_message.remove(&deleted_message.id);
625 self.checkpoints_by_message.remove(&deleted_message.id);
626 }
627 cx.notify();
628 }
629
630 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
631 self.context_by_message
632 .get(&id)
633 .into_iter()
634 .flat_map(|context| {
635 context
636 .iter()
637 .filter_map(|context_id| self.context.get(&context_id))
638 })
639 }
640
641 /// Returns whether all of the tool uses have finished running.
642 pub fn all_tools_finished(&self) -> bool {
643 // If the only pending tool uses left are the ones with errors, then
644 // that means that we've finished running all of the pending tools.
645 self.tool_use
646 .pending_tool_uses()
647 .iter()
648 .all(|tool_use| tool_use.status.is_error())
649 }
650
651 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
652 self.tool_use.tool_uses_for_message(id, cx)
653 }
654
655 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
656 self.tool_use.tool_results_for_message(id)
657 }
658
659 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
660 self.tool_use.tool_result(id)
661 }
662
663 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
664 Some(&self.tool_use.tool_result(id)?.content)
665 }
666
667 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
668 self.tool_use.tool_result_card(id).cloned()
669 }
670
671 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
672 self.tool_use.message_has_tool_results(message_id)
673 }
674
    /// Filters out context that has already been included in previous messages.
676 pub fn filter_new_context<'a>(
677 &self,
678 context: impl Iterator<Item = &'a AssistantContext>,
679 ) -> impl Iterator<Item = &'a AssistantContext> {
680 context.filter(|ctx| self.is_context_new(ctx))
681 }
682
683 fn is_context_new(&self, context: &AssistantContext) -> bool {
684 !self.context.contains_key(&context.id())
685 }
686
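    /// Inserts a user message, attaching only context that has not already been
    /// included earlier in the thread, and records a pending checkpoint when a git
    /// checkpoint is provided.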
687 pub fn insert_user_message(
688 &mut self,
689 text: impl Into<String>,
690 context: Vec<AssistantContext>,
691 git_checkpoint: Option<GitStoreCheckpoint>,
692 cx: &mut Context<Self>,
693 ) -> MessageId {
694 let text = text.into();
695
696 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
697
698 let new_context: Vec<_> = context
699 .into_iter()
700 .filter(|ctx| self.is_context_new(ctx))
701 .collect();
702
703 if !new_context.is_empty() {
704 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
705 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
706 message.context = context_string;
707 }
708 }
709
710 self.action_log.update(cx, |log, cx| {
711 // Track all buffers added as context
712 for ctx in &new_context {
713 match ctx {
714 AssistantContext::File(file_ctx) => {
715 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
716 }
717 AssistantContext::Directory(dir_ctx) => {
718 for context_buffer in &dir_ctx.context_buffers {
719 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
720 }
721 }
722 AssistantContext::Symbol(symbol_ctx) => {
723 log.buffer_added_as_context(
724 symbol_ctx.context_symbol.buffer.clone(),
725 cx,
726 );
727 }
728 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
729 }
730 }
731 });
732 }
733
734 let context_ids = new_context
735 .iter()
736 .map(|context| context.id())
737 .collect::<Vec<_>>();
738 self.context.extend(
739 new_context
740 .into_iter()
741 .map(|context| (context.id(), context)),
742 );
743 self.context_by_message.insert(message_id, context_ids);
744
745 if let Some(git_checkpoint) = git_checkpoint {
746 self.pending_checkpoint = Some(ThreadCheckpoint {
747 message_id,
748 git_checkpoint,
749 });
750 }
751
752 self.auto_capture_telemetry(cx);
753
754 message_id
755 }
756
757 pub fn insert_message(
758 &mut self,
759 role: Role,
760 segments: Vec<MessageSegment>,
761 cx: &mut Context<Self>,
762 ) -> MessageId {
763 let id = self.next_message_id.post_inc();
764 self.messages.push(Message {
765 id,
766 role,
767 segments,
768 context: String::new(),
769 });
770 self.touch_updated_at();
771 cx.emit(ThreadEvent::MessageAdded(id));
772 id
773 }
774
775 pub fn edit_message(
776 &mut self,
777 id: MessageId,
778 new_role: Role,
779 new_segments: Vec<MessageSegment>,
780 cx: &mut Context<Self>,
781 ) -> bool {
782 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
783 return false;
784 };
785 message.role = new_role;
786 message.segments = new_segments;
787 self.touch_updated_at();
788 cx.emit(ThreadEvent::MessageEdited(id));
789 true
790 }
791
792 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
793 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
794 return false;
795 };
796 self.messages.remove(index);
797 self.context_by_message.remove(&id);
798 self.touch_updated_at();
799 cx.emit(ThreadEvent::MessageDeleted(id));
800 true
801 }
802
803 /// Returns the representation of this [`Thread`] in a textual form.
804 ///
805 /// This is the representation we use when attaching a thread as context to another thread.
806 pub fn text(&self) -> String {
807 let mut text = String::new();
808
809 for message in &self.messages {
810 text.push_str(match message.role {
811 language_model::Role::User => "User:",
812 language_model::Role::Assistant => "Assistant:",
813 language_model::Role::System => "System:",
814 });
815 text.push('\n');
816
817 for segment in &message.segments {
818 match segment {
819 MessageSegment::Text(content) => text.push_str(content),
820 MessageSegment::Thinking(content) => {
821 text.push_str(&format!("<think>{}</think>", content))
822 }
823 }
824 }
825 text.push('\n');
826 }
827
828 text
829 }
830
831 /// Serializes this thread into a format for storage or telemetry.
832 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
833 let initial_project_snapshot = self.initial_project_snapshot.clone();
834 cx.spawn(async move |this, cx| {
835 let initial_project_snapshot = initial_project_snapshot.await;
836 this.read_with(cx, |this, cx| SerializedThread {
837 version: SerializedThread::VERSION.to_string(),
838 summary: this.summary_or_default(),
839 updated_at: this.updated_at(),
840 messages: this
841 .messages()
842 .map(|message| SerializedMessage {
843 id: message.id,
844 role: message.role,
845 segments: message
846 .segments
847 .iter()
848 .map(|segment| match segment {
849 MessageSegment::Text(text) => {
850 SerializedMessageSegment::Text { text: text.clone() }
851 }
852 MessageSegment::Thinking(text) => {
853 SerializedMessageSegment::Thinking { text: text.clone() }
854 }
855 })
856 .collect(),
857 tool_uses: this
858 .tool_uses_for_message(message.id, cx)
859 .into_iter()
860 .map(|tool_use| SerializedToolUse {
861 id: tool_use.id,
862 name: tool_use.name,
863 input: tool_use.input,
864 })
865 .collect(),
866 tool_results: this
867 .tool_results_for_message(message.id)
868 .into_iter()
869 .map(|tool_result| SerializedToolResult {
870 tool_use_id: tool_result.tool_use_id.clone(),
871 is_error: tool_result.is_error,
872 content: tool_result.content.clone(),
873 })
874 .collect(),
875 context: message.context.clone(),
876 })
877 .collect(),
878 initial_project_snapshot,
879 cumulative_token_usage: this.cumulative_token_usage,
880 request_token_usage: this.request_token_usage.clone(),
881 detailed_summary_state: this.detailed_summary_state.clone(),
882 exceeded_window_error: this.exceeded_window_error.clone(),
883 })
884 })
885 }
886
887 pub fn send_to_model(
888 &mut self,
889 model: Arc<dyn LanguageModel>,
890 request_kind: RequestKind,
891 cx: &mut Context<Self>,
892 ) {
893 let mut request = self.to_completion_request(request_kind, cx);
894 if model.supports_tools() {
895 request.tools = {
896 let mut tools = Vec::new();
897 tools.extend(
898 self.tools()
899 .read(cx)
900 .enabled_tools(cx)
901 .into_iter()
902 .filter_map(|tool| {
                        // Skip tools whose input schema can't be generated for this model.
904 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
905 Some(LanguageModelRequestTool {
906 name: tool.name(),
907 description: tool.description(),
908 input_schema,
909 })
910 }),
911 );
912
913 tools
914 };
915 }
916
917 self.stream_completion(request, model, cx);
918 }
919
920 pub fn used_tools_since_last_user_message(&self) -> bool {
921 for message in self.messages.iter().rev() {
922 if self.tool_use.message_has_tool_results(message.id) {
923 return true;
924 } else if message.role == Role::User {
925 return false;
926 }
927 }
928
929 false
930 }
931
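    /// Builds the [`LanguageModelRequest`] for this thread: a system prompt derived
    /// from the project context, followed by each message (with tool uses and results
    /// attached for [`RequestKind::Chat`]), with the final message marked as cacheable
    /// and any stale-file notices appended at the end.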
932 pub fn to_completion_request(
933 &self,
934 request_kind: RequestKind,
935 cx: &App,
936 ) -> LanguageModelRequest {
937 let mut request = LanguageModelRequest {
938 messages: vec![],
939 tools: Vec::new(),
940 stop: Vec::new(),
941 temperature: None,
942 };
943
944 if let Some(project_context) = self.project_context.borrow().as_ref() {
945 if let Some(system_prompt) = self
946 .prompt_builder
947 .generate_assistant_system_prompt(project_context)
948 .context("failed to generate assistant system prompt")
949 .log_err()
950 {
951 request.messages.push(LanguageModelRequestMessage {
952 role: Role::System,
953 content: vec![MessageContent::Text(system_prompt)],
954 cache: true,
955 });
956 }
957 } else {
958 log::error!("project_context not set.")
959 }
960
961 for message in &self.messages {
962 let mut request_message = LanguageModelRequestMessage {
963 role: message.role,
964 content: Vec::new(),
965 cache: false,
966 };
967
968 match request_kind {
969 RequestKind::Chat => {
970 self.tool_use
971 .attach_tool_results(message.id, &mut request_message);
972 }
973 RequestKind::Summarize => {
974 // We don't care about tool use during summarization.
975 if self.tool_use.message_has_tool_results(message.id) {
976 continue;
977 }
978 }
979 }
980
981 if !message.segments.is_empty() {
982 request_message
983 .content
984 .push(MessageContent::Text(message.to_string()));
985 }
986
987 match request_kind {
988 RequestKind::Chat => {
989 self.tool_use
990 .attach_tool_uses(message.id, &mut request_message);
991 }
992 RequestKind::Summarize => {
993 // We don't care about tool use during summarization.
994 }
995 };
996
997 request.messages.push(request_message);
998 }
999
1000 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1001 if let Some(last) = request.messages.last_mut() {
1002 last.cache = true;
1003 }
1004
1005 self.attached_tracked_files_state(&mut request.messages, cx);
1006
1007 request
1008 }
1009
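    /// Appends a synthetic user message listing files that changed since the model
    /// last read them, plus a reminder to re-check project diagnostics after edits,
    /// so the model works from up-to-date buffer state.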
1010 fn attached_tracked_files_state(
1011 &self,
1012 messages: &mut Vec<LanguageModelRequestMessage>,
1013 cx: &App,
1014 ) {
1015 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1016
1017 let mut stale_message = String::new();
1018
1019 let action_log = self.action_log.read(cx);
1020
1021 for stale_file in action_log.stale_buffers(cx) {
1022 let Some(file) = stale_file.read(cx).file() else {
1023 continue;
1024 };
1025
1026 if stale_message.is_empty() {
                writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
1028 }
1029
1030 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1031 }
1032
1033 let mut content = Vec::with_capacity(2);
1034
1035 if !stale_message.is_empty() {
1036 content.push(stale_message.into());
1037 }
1038
1039 if action_log.has_edited_files_since_project_diagnostics_check() {
1040 content.push(
1041 "\n\nWhen you're done making changes, make sure to check project diagnostics \
1042 and fix all errors AND warnings you introduced! \
1043 DO NOT mention you're going to do this until you're done."
1044 .into(),
1045 );
1046 }
1047
1048 if !content.is_empty() {
1049 let context_message = LanguageModelRequestMessage {
1050 role: Role::User,
1051 content,
1052 cache: false,
1053 };
1054
1055 messages.push(context_message);
1056 }
1057 }
1058
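    /// Streams a completion from the model, applying each event to the thread as it
    /// arrives (new assistant messages, text and thinking chunks, tool-use requests,
    /// token-usage updates) and reporting errors and telemetry when the stream ends.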
1059 pub fn stream_completion(
1060 &mut self,
1061 request: LanguageModelRequest,
1062 model: Arc<dyn LanguageModel>,
1063 cx: &mut Context<Self>,
1064 ) {
1065 let pending_completion_id = post_inc(&mut self.completion_count);
1066 let task = cx.spawn(async move |thread, cx| {
1067 let stream = model.stream_completion(request, &cx);
1068 let initial_token_usage =
1069 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1070 let stream_completion = async {
1071 let mut events = stream.await?;
1072 let mut stop_reason = StopReason::EndTurn;
1073 let mut current_token_usage = TokenUsage::default();
1074
1075 while let Some(event) = events.next().await {
1076 let event = event?;
1077
1078 thread.update(cx, |thread, cx| {
1079 match event {
1080 LanguageModelCompletionEvent::StartMessage { .. } => {
1081 thread.insert_message(
1082 Role::Assistant,
1083 vec![MessageSegment::Text(String::new())],
1084 cx,
1085 );
1086 }
1087 LanguageModelCompletionEvent::Stop(reason) => {
1088 stop_reason = reason;
1089 }
1090 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1091 thread.update_token_usage_at_last_message(token_usage);
1092 thread.cumulative_token_usage = thread.cumulative_token_usage
1093 + token_usage
1094 - current_token_usage;
1095 current_token_usage = token_usage;
1096 }
1097 LanguageModelCompletionEvent::Text(chunk) => {
1098 if let Some(last_message) = thread.messages.last_mut() {
1099 if last_message.role == Role::Assistant {
1100 last_message.push_text(&chunk);
1101 cx.emit(ThreadEvent::StreamedAssistantText(
1102 last_message.id,
1103 chunk,
1104 ));
1105 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1111 thread.insert_message(
1112 Role::Assistant,
1113 vec![MessageSegment::Text(chunk.to_string())],
1114 cx,
1115 );
1116 };
1117 }
1118 }
1119 LanguageModelCompletionEvent::Thinking(chunk) => {
1120 if let Some(last_message) = thread.messages.last_mut() {
1121 if last_message.role == Role::Assistant {
1122 last_message.push_thinking(&chunk);
1123 cx.emit(ThreadEvent::StreamedAssistantThinking(
1124 last_message.id,
1125 chunk,
1126 ));
1127 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1133 thread.insert_message(
1134 Role::Assistant,
1135 vec![MessageSegment::Thinking(chunk.to_string())],
1136 cx,
1137 );
1138 };
1139 }
1140 }
1141 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1142 let last_assistant_message_id = thread
1143 .messages
1144 .iter_mut()
1145 .rfind(|message| message.role == Role::Assistant)
1146 .map(|message| message.id)
1147 .unwrap_or_else(|| {
1148 thread.insert_message(Role::Assistant, vec![], cx)
1149 });
1150
1151 thread.tool_use.request_tool_use(
1152 last_assistant_message_id,
1153 tool_use,
1154 cx,
1155 );
1156 }
1157 }
1158
1159 thread.touch_updated_at();
1160 cx.emit(ThreadEvent::StreamedCompletion);
1161 cx.notify();
1162
1163 thread.auto_capture_telemetry(cx);
1164 })?;
1165
1166 smol::future::yield_now().await;
1167 }
1168
1169 thread.update(cx, |thread, cx| {
1170 thread
1171 .pending_completions
1172 .retain(|completion| completion.id != pending_completion_id);
1173
1174 if thread.summary.is_none() && thread.messages.len() >= 2 {
1175 thread.summarize(cx);
1176 }
1177 })?;
1178
1179 anyhow::Ok(stop_reason)
1180 };
1181
1182 let result = stream_completion.await;
1183
1184 thread
1185 .update(cx, |thread, cx| {
1186 thread.finalize_pending_checkpoint(cx);
1187 match result.as_ref() {
1188 Ok(stop_reason) => match stop_reason {
1189 StopReason::ToolUse => {
1190 let tool_uses = thread.use_pending_tools(cx);
1191 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1192 }
1193 StopReason::EndTurn => {}
1194 StopReason::MaxTokens => {}
1195 },
1196 Err(error) => {
1197 if error.is::<PaymentRequiredError>() {
1198 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1199 } else if error.is::<MaxMonthlySpendReachedError>() {
1200 cx.emit(ThreadEvent::ShowError(
1201 ThreadError::MaxMonthlySpendReached,
1202 ));
1203 } else if let Some(error) =
1204 error.downcast_ref::<ModelRequestLimitReachedError>()
1205 {
1206 cx.emit(ThreadEvent::ShowError(
1207 ThreadError::ModelRequestLimitReached { plan: error.plan },
1208 ));
1209 } else if let Some(known_error) =
1210 error.downcast_ref::<LanguageModelKnownError>()
1211 {
1212 match known_error {
1213 LanguageModelKnownError::ContextWindowLimitExceeded {
1214 tokens,
1215 } => {
1216 thread.exceeded_window_error = Some(ExceededWindowError {
1217 model_id: model.id(),
1218 token_count: *tokens,
1219 });
1220 cx.notify();
1221 }
1222 }
1223 } else {
1224 let error_message = error
1225 .chain()
1226 .map(|err| err.to_string())
1227 .collect::<Vec<_>>()
1228 .join("\n");
1229 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1230 header: "Error interacting with language model".into(),
1231 message: SharedString::from(error_message.clone()),
1232 }));
1233 }
1234
1235 thread.cancel_last_completion(cx);
1236 }
1237 }
1238 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1239
1240 thread.auto_capture_telemetry(cx);
1241
1242 if let Ok(initial_usage) = initial_token_usage {
1243 let usage = thread.cumulative_token_usage - initial_usage;
1244
1245 telemetry::event!(
1246 "Assistant Thread Completion",
1247 thread_id = thread.id().to_string(),
1248 model = model.telemetry_id(),
1249 model_provider = model.provider_id().to_string(),
1250 input_tokens = usage.input_tokens,
1251 output_tokens = usage.output_tokens,
1252 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1253 cache_read_input_tokens = usage.cache_read_input_tokens,
1254 );
1255 }
1256 })
1257 .ok();
1258 });
1259
1260 self.pending_completions.push(PendingCompletion {
1261 id: pending_completion_id,
1262 _task: task,
1263 });
1264 }
1265
1266 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1267 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1268 return;
1269 };
1270
1271 if !model.provider.is_authenticated(cx) {
1272 return;
1273 }
1274
1275 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1276 request.messages.push(LanguageModelRequestMessage {
1277 role: Role::User,
1278 content: vec![
1279 "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1280 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1281 If the conversation is about a specific subject, include it in the title. \
1282 Be descriptive. DO NOT speak in the first person."
1283 .into(),
1284 ],
1285 cache: false,
1286 });
1287
1288 self.pending_summary = cx.spawn(async move |this, cx| {
1289 async move {
1290 let stream = model.model.stream_completion_text(request, &cx);
1291 let mut messages = stream.await?;
1292
1293 let mut new_summary = String::new();
1294 while let Some(message) = messages.stream.next().await {
1295 let text = message?;
1296 let mut lines = text.lines();
1297 new_summary.extend(lines.next());
1298
1299 // Stop if the LLM generated multiple lines.
1300 if lines.next().is_some() {
1301 break;
1302 }
1303 }
1304
1305 this.update(cx, |this, cx| {
1306 if !new_summary.is_empty() {
1307 this.summary = Some(new_summary.into());
1308 }
1309
1310 cx.emit(ThreadEvent::SummaryGenerated);
1311 })?;
1312
1313 anyhow::Ok(())
1314 }
1315 .log_err()
1316 .await
1317 });
1318 }
1319
1320 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1321 let last_message_id = self.messages.last().map(|message| message.id)?;
1322
1323 match &self.detailed_summary_state {
1324 DetailedSummaryState::Generating { message_id, .. }
1325 | DetailedSummaryState::Generated { message_id, .. }
1326 if *message_id == last_message_id =>
1327 {
1328 // Already up-to-date
1329 return None;
1330 }
1331 _ => {}
1332 }
1333
1334 let ConfiguredModel { model, provider } =
1335 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1336
1337 if !provider.is_authenticated(cx) {
1338 return None;
1339 }
1340
1341 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1342
1343 request.messages.push(LanguageModelRequestMessage {
1344 role: Role::User,
1345 content: vec![
1346 "Generate a detailed summary of this conversation. Include:\n\
1347 1. A brief overview of what was discussed\n\
1348 2. Key facts or information discovered\n\
1349 3. Outcomes or conclusions reached\n\
1350 4. Any action items or next steps if any\n\
1351 Format it in Markdown with headings and bullet points."
1352 .into(),
1353 ],
1354 cache: false,
1355 });
1356
1357 let task = cx.spawn(async move |thread, cx| {
1358 let stream = model.stream_completion_text(request, &cx);
1359 let Some(mut messages) = stream.await.log_err() else {
1360 thread
1361 .update(cx, |this, _cx| {
1362 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1363 })
1364 .log_err();
1365
1366 return;
1367 };
1368
1369 let mut new_detailed_summary = String::new();
1370
1371 while let Some(chunk) = messages.stream.next().await {
1372 if let Some(chunk) = chunk.log_err() {
1373 new_detailed_summary.push_str(&chunk);
1374 }
1375 }
1376
1377 thread
1378 .update(cx, |this, _cx| {
1379 this.detailed_summary_state = DetailedSummaryState::Generated {
1380 text: new_detailed_summary.into(),
1381 message_id: last_message_id,
1382 };
1383 })
1384 .log_err();
1385 });
1386
1387 self.detailed_summary_state = DetailedSummaryState::Generating {
1388 message_id: last_message_id,
1389 };
1390
1391 Some(task)
1392 }
1393
1394 pub fn is_generating_detailed_summary(&self) -> bool {
1395 matches!(
1396 self.detailed_summary_state,
1397 DetailedSummaryState::Generating { .. }
1398 )
1399 }
1400
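    /// Starts every idle pending tool use: tools that require confirmation (and are not
    /// covered by `always_allow_tool_actions`) are queued for user approval, while the
    /// rest are run immediately against the current completion request.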
1401 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1402 self.auto_capture_telemetry(cx);
1403 let request = self.to_completion_request(RequestKind::Chat, cx);
1404 let messages = Arc::new(request.messages);
1405 let pending_tool_uses = self
1406 .tool_use
1407 .pending_tool_uses()
1408 .into_iter()
1409 .filter(|tool_use| tool_use.status.is_idle())
1410 .cloned()
1411 .collect::<Vec<_>>();
1412
1413 for tool_use in pending_tool_uses.iter() {
1414 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1415 if tool.needs_confirmation(&tool_use.input, cx)
1416 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1417 {
1418 self.tool_use.confirm_tool_use(
1419 tool_use.id.clone(),
1420 tool_use.ui_text.clone(),
1421 tool_use.input.clone(),
1422 messages.clone(),
1423 tool,
1424 );
1425 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1426 } else {
1427 self.run_tool(
1428 tool_use.id.clone(),
1429 tool_use.ui_text.clone(),
1430 tool_use.input.clone(),
1431 &messages,
1432 tool,
1433 cx,
1434 );
1435 }
1436 }
1437 }
1438
1439 pending_tool_uses
1440 }
1441
1442 pub fn run_tool(
1443 &mut self,
1444 tool_use_id: LanguageModelToolUseId,
1445 ui_text: impl Into<SharedString>,
1446 input: serde_json::Value,
1447 messages: &[LanguageModelRequestMessage],
1448 tool: Arc<dyn Tool>,
1449 cx: &mut Context<Thread>,
1450 ) {
1451 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1452 self.tool_use
1453 .run_pending_tool(tool_use_id, ui_text.into(), task);
1454 }
1455
1456 fn spawn_tool_use(
1457 &mut self,
1458 tool_use_id: LanguageModelToolUseId,
1459 messages: &[LanguageModelRequestMessage],
1460 input: serde_json::Value,
1461 tool: Arc<dyn Tool>,
1462 cx: &mut Context<Thread>,
1463 ) -> Task<()> {
1464 let tool_name: Arc<str> = tool.name().into();
1465
1466 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1467 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1468 } else {
1469 tool.run(
1470 input,
1471 messages,
1472 self.project.clone(),
1473 self.action_log.clone(),
1474 cx,
1475 )
1476 };
1477
1478 // Store the card separately if it exists
1479 if let Some(card) = tool_result.card.clone() {
1480 self.tool_use
1481 .insert_tool_result_card(tool_use_id.clone(), card);
1482 }
1483
1484 cx.spawn({
1485 async move |thread: WeakEntity<Thread>, cx| {
1486 let output = tool_result.output.await;
1487
1488 thread
1489 .update(cx, |thread, cx| {
1490 let pending_tool_use = thread.tool_use.insert_tool_output(
1491 tool_use_id.clone(),
1492 tool_name,
1493 output,
1494 cx,
1495 );
1496 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1497 })
1498 .ok();
1499 }
1500 })
1501 }
1502
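    /// Called when a single tool use completes. Once every pending tool has finished,
    /// the accumulated results are attached as a new user message and, unless the run
    /// was canceled, a follow-up request is sent to the default model.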
1503 fn tool_finished(
1504 &mut self,
1505 tool_use_id: LanguageModelToolUseId,
1506 pending_tool_use: Option<PendingToolUse>,
1507 canceled: bool,
1508 cx: &mut Context<Self>,
1509 ) {
1510 if self.all_tools_finished() {
1511 let model_registry = LanguageModelRegistry::read_global(cx);
1512 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1513 self.attach_tool_results(cx);
1514 if !canceled {
1515 self.send_to_model(model, RequestKind::Chat, cx);
1516 }
1517 }
1518 }
1519
1520 cx.emit(ThreadEvent::ToolFinished {
1521 tool_use_id,
1522 pending_tool_use,
1523 });
1524 }
1525
1526 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1527 // Insert a user message to contain the tool results.
1528 self.insert_user_message(
1529 // TODO: Sending up a user message without any content results in the model sending back
1530 // responses that also don't have any content. We currently don't handle this case well,
1531 // so for now we provide some text to keep the model on track.
1532 "Here are the tool results.",
1533 Vec::new(),
1534 None,
1535 cx,
1536 );
1537 }
1538
1539 /// Cancels the last pending completion, if there are any pending.
1540 ///
1541 /// Returns whether a completion was canceled.
1542 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1543 let canceled = if self.pending_completions.pop().is_some() {
1544 true
1545 } else {
1546 let mut canceled = false;
1547 for pending_tool_use in self.tool_use.cancel_pending() {
1548 canceled = true;
1549 self.tool_finished(
1550 pending_tool_use.id.clone(),
1551 Some(pending_tool_use),
1552 true,
1553 cx,
1554 );
1555 }
1556 canceled
1557 };
1558 self.finalize_pending_checkpoint(cx);
1559 canceled
1560 }
1561
1562 pub fn feedback(&self) -> Option<ThreadFeedback> {
1563 self.feedback
1564 }
1565
1566 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1567 self.message_feedback.get(&message_id).copied()
1568 }
1569
1570 pub fn report_message_feedback(
1571 &mut self,
1572 message_id: MessageId,
1573 feedback: ThreadFeedback,
1574 cx: &mut Context<Self>,
1575 ) -> Task<Result<()>> {
1576 if self.message_feedback.get(&message_id) == Some(&feedback) {
1577 return Task::ready(Ok(()));
1578 }
1579
1580 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1581 let serialized_thread = self.serialize(cx);
1582 let thread_id = self.id().clone();
1583 let client = self.project.read(cx).client();
1584
1585 let enabled_tool_names: Vec<String> = self
1586 .tools()
1587 .read(cx)
1588 .enabled_tools(cx)
1589 .iter()
1590 .map(|tool| tool.name().to_string())
1591 .collect();
1592
1593 self.message_feedback.insert(message_id, feedback);
1594
1595 cx.notify();
1596
1597 let message_content = self
1598 .message(message_id)
1599 .map(|msg| msg.to_string())
1600 .unwrap_or_default();
1601
1602 cx.background_spawn(async move {
1603 let final_project_snapshot = final_project_snapshot.await;
1604 let serialized_thread = serialized_thread.await?;
1605 let thread_data =
1606 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1607
1608 let rating = match feedback {
1609 ThreadFeedback::Positive => "positive",
1610 ThreadFeedback::Negative => "negative",
1611 };
1612 telemetry::event!(
1613 "Assistant Thread Rated",
1614 rating,
1615 thread_id,
1616 enabled_tool_names,
1617 message_id = message_id.0,
1618 message_content,
1619 thread_data,
1620 final_project_snapshot
1621 );
1622 client.telemetry().flush_events();
1623
1624 Ok(())
1625 })
1626 }
1627
1628 pub fn report_feedback(
1629 &mut self,
1630 feedback: ThreadFeedback,
1631 cx: &mut Context<Self>,
1632 ) -> Task<Result<()>> {
1633 let last_assistant_message_id = self
1634 .messages
1635 .iter()
1636 .rev()
1637 .find(|msg| msg.role == Role::Assistant)
1638 .map(|msg| msg.id);
1639
1640 if let Some(message_id) = last_assistant_message_id {
1641 self.report_message_feedback(message_id, feedback, cx)
1642 } else {
1643 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1644 let serialized_thread = self.serialize(cx);
1645 let thread_id = self.id().clone();
1646 let client = self.project.read(cx).client();
1647 self.feedback = Some(feedback);
1648 cx.notify();
1649
1650 cx.background_spawn(async move {
1651 let final_project_snapshot = final_project_snapshot.await;
1652 let serialized_thread = serialized_thread.await?;
1653 let thread_data = serde_json::to_value(serialized_thread)
1654 .unwrap_or_else(|_| serde_json::Value::Null);
1655
1656 let rating = match feedback {
1657 ThreadFeedback::Positive => "positive",
1658 ThreadFeedback::Negative => "negative",
1659 };
1660 telemetry::event!(
1661 "Assistant Thread Rated",
1662 rating,
1663 thread_id,
1664 thread_data,
1665 final_project_snapshot
1666 );
1667 client.telemetry().flush_events();
1668
1669 Ok(())
1670 })
1671 }
1672 }
1673
1674 /// Create a snapshot of the current project state including git information and unsaved buffers.
1675 fn project_snapshot(
1676 project: Entity<Project>,
1677 cx: &mut Context<Self>,
1678 ) -> Task<Arc<ProjectSnapshot>> {
1679 let git_store = project.read(cx).git_store().clone();
1680 let worktree_snapshots: Vec<_> = project
1681 .read(cx)
1682 .visible_worktrees(cx)
1683 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1684 .collect();
1685
1686 cx.spawn(async move |_, cx| {
1687 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1688
1689 let mut unsaved_buffers = Vec::new();
1690 cx.update(|app_cx| {
1691 let buffer_store = project.read(app_cx).buffer_store();
1692 for buffer_handle in buffer_store.read(app_cx).buffers() {
1693 let buffer = buffer_handle.read(app_cx);
1694 if buffer.is_dirty() {
1695 if let Some(file) = buffer.file() {
1696 let path = file.path().to_string_lossy().to_string();
1697 unsaved_buffers.push(path);
1698 }
1699 }
1700 }
1701 })
1702 .ok();
1703
1704 Arc::new(ProjectSnapshot {
1705 worktree_snapshots,
1706 unsaved_buffer_paths: unsaved_buffers,
1707 timestamp: Utc::now(),
1708 })
1709 })
1710 }
1711
1712 fn worktree_snapshot(
1713 worktree: Entity<project::Worktree>,
1714 git_store: Entity<GitStore>,
1715 cx: &App,
1716 ) -> Task<WorktreeSnapshot> {
1717 cx.spawn(async move |cx| {
1718 // Get worktree path and snapshot
1719 let worktree_info = cx.update(|app_cx| {
1720 let worktree = worktree.read(app_cx);
1721 let path = worktree.abs_path().to_string_lossy().to_string();
1722 let snapshot = worktree.snapshot();
1723 (path, snapshot)
1724 });
1725
1726 let Ok((worktree_path, _snapshot)) = worktree_info else {
1727 return WorktreeSnapshot {
1728 worktree_path: String::new(),
1729 git_state: None,
1730 };
1731 };
1732
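            // Locate the repository containing this worktree, then gather branch, remote
            // URL, HEAD SHA, and diff on its background job; any failure along the way
            // degrades to a snapshot without git state rather than an error.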
1733 let git_state = git_store
1734 .update(cx, |git_store, cx| {
1735 git_store
1736 .repositories()
1737 .values()
1738 .find(|repo| {
1739 repo.read(cx)
1740 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1741 .is_some()
1742 })
1743 .cloned()
1744 })
1745 .ok()
1746 .flatten()
1747 .map(|repo| {
1748 repo.update(cx, |repo, _| {
1749 let current_branch =
1750 repo.branch.as_ref().map(|branch| branch.name.to_string());
1751 repo.send_job(None, |state, _| async move {
1752 let RepositoryState::Local { backend, .. } = state else {
1753 return GitState {
1754 remote_url: None,
1755 head_sha: None,
1756 current_branch,
1757 diff: None,
1758 };
1759 };
1760
1761 let remote_url = backend.remote_url("origin");
1762 let head_sha = backend.head_sha();
1763 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1764
1765 GitState {
1766 remote_url,
1767 head_sha,
1768 current_branch,
1769 diff,
1770 }
1771 })
1772 })
1773 });
1774
1775 let git_state = match git_state {
1776 Some(git_state) => match git_state.ok() {
1777 Some(git_state) => git_state.await.ok(),
1778 None => None,
1779 },
1780 None => None,
1781 };
1782
1783 WorktreeSnapshot {
1784 worktree_path,
1785 git_state,
1786 }
1787 })
1788 }
1789
1790 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1791 let mut markdown = Vec::new();
1792
1793 if let Some(summary) = self.summary() {
1794 writeln!(markdown, "# {summary}\n")?;
1795 };
1796
1797 for message in self.messages() {
1798 writeln!(
1799 markdown,
1800 "## {role}\n",
1801 role = match message.role {
1802 Role::User => "User",
1803 Role::Assistant => "Assistant",
1804 Role::System => "System",
1805 }
1806 )?;
1807
1808 if !message.context.is_empty() {
1809 writeln!(markdown, "{}", message.context)?;
1810 }
1811
1812 for segment in &message.segments {
1813 match segment {
1814 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1815 MessageSegment::Thinking(text) => {
1816 writeln!(markdown, "<think>{}</think>\n", text)?
1817 }
1818 }
1819 }
1820
1821 for tool_use in self.tool_uses_for_message(message.id, cx) {
1822 writeln!(
1823 markdown,
1824 "**Use Tool: {} ({})**",
1825 tool_use.name, tool_use.id
1826 )?;
1827 writeln!(markdown, "```json")?;
1828 writeln!(
1829 markdown,
1830 "{}",
1831 serde_json::to_string_pretty(&tool_use.input)?
1832 )?;
1833 writeln!(markdown, "```")?;
1834 }
1835
1836 for tool_result in self.tool_results_for_message(message.id) {
1837 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1838 if tool_result.is_error {
1839 write!(markdown, " (Error)")?;
1840 }
1841
1842 writeln!(markdown, "**\n")?;
1843 writeln!(markdown, "{}", tool_result.content)?;
1844 }
1845 }
1846
1847 Ok(String::from_utf8_lossy(&markdown).to_string())
1848 }
1849
1850 pub fn keep_edits_in_range(
1851 &mut self,
1852 buffer: Entity<language::Buffer>,
1853 buffer_range: Range<language::Anchor>,
1854 cx: &mut Context<Self>,
1855 ) {
1856 self.action_log.update(cx, |action_log, cx| {
1857 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1858 });
1859 }
1860
1861 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1862 self.action_log
1863 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1864 }
1865
1866 pub fn reject_edits_in_ranges(
1867 &mut self,
1868 buffer: Entity<language::Buffer>,
1869 buffer_ranges: Vec<Range<language::Anchor>>,
1870 cx: &mut Context<Self>,
1871 ) -> Task<Result<()>> {
1872 self.action_log.update(cx, |action_log, cx| {
1873 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1874 })
1875 }
1876
1877 pub fn action_log(&self) -> &Entity<ActionLog> {
1878 &self.action_log
1879 }
1880
1881 pub fn project(&self) -> &Entity<Project> {
1882 &self.project
1883 }
1884
1885 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1886 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1887 return;
1888 }
1889
1890 let now = Instant::now();
1891 if let Some(last) = self.last_auto_capture_at {
1892 if now.duration_since(last).as_secs() < 10 {
1893 return;
1894 }
1895 }
1896
1897 self.last_auto_capture_at = Some(now);
1898
1899 let thread_id = self.id().clone();
1900 let github_login = self
1901 .project
1902 .read(cx)
1903 .user_store()
1904 .read(cx)
1905 .current_user()
1906 .map(|user| user.github_login.clone());
1907 let client = self.project.read(cx).client().clone();
1908 let serialize_task = self.serialize(cx);
1909
1910 cx.background_executor()
1911 .spawn(async move {
1912 if let Ok(serialized_thread) = serialize_task.await {
1913 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1914 telemetry::event!(
1915 "Agent Thread Auto-Captured",
1916 thread_id = thread_id.to_string(),
1917 thread_data = thread_data,
1918 auto_capture_reason = "tracked_user",
1919 github_login = github_login
1920 );
1921
1922 client.telemetry().flush_events();
1923 }
1924 }
1925 })
1926 .detach();
1927 }
1928
1929 pub fn cumulative_token_usage(&self) -> TokenUsage {
1930 self.cumulative_token_usage
1931 }
1932
1933 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
1934 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
1935 return TotalTokenUsage::default();
1936 };
1937
1938 let max = model.model.max_token_count();
1939
1940 let index = self
1941 .messages
1942 .iter()
1943 .position(|msg| msg.id == message_id)
1944 .unwrap_or(0);
1945
1946 if index == 0 {
1947 return TotalTokenUsage { total: 0, max };
1948 }
1949
1950 let token_usage = &self
1951 .request_token_usage
1952 .get(index - 1)
1953 .cloned()
1954 .unwrap_or_default();
1955
1956 TotalTokenUsage {
1957 total: token_usage.total_tokens() as usize,
1958 max,
1959 }
1960 }
1961
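    /// Reports usage against the currently selected model's context window. If the last
    /// request overflowed the window with this same model, the recorded overflow count
    /// is reported instead of the last successful usage.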
1962 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1963 let model_registry = LanguageModelRegistry::read_global(cx);
1964 let Some(model) = model_registry.default_model() else {
1965 return TotalTokenUsage::default();
1966 };
1967
1968 let max = model.model.max_token_count();
1969
1970 if let Some(exceeded_error) = &self.exceeded_window_error {
1971 if model.model.id() == exceeded_error.model_id {
1972 return TotalTokenUsage {
1973 total: exceeded_error.token_count,
1974 max,
1975 };
1976 }
1977 }
1978
1979 let total = self
1980 .token_usage_at_last_message()
1981 .unwrap_or_default()
1982 .total_tokens() as usize;
1983
1984 TotalTokenUsage { total, max }
1985 }
1986
1987 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
1988 self.request_token_usage
1989 .get(self.messages.len().saturating_sub(1))
1990 .or_else(|| self.request_token_usage.last())
1991 .cloned()
1992 }
1993
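    // `request_token_usage` is kept parallel to `messages`, one entry per message; the
    // vector is padded with the previous usage before the final entry is updated.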
1994 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
1995 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
1996 self.request_token_usage
1997 .resize(self.messages.len(), placeholder);
1998
1999 if let Some(last) = self.request_token_usage.last_mut() {
2000 *last = token_usage;
2001 }
2002 }
2003
2004 pub fn deny_tool_use(
2005 &mut self,
2006 tool_use_id: LanguageModelToolUseId,
2007 tool_name: Arc<str>,
2008 cx: &mut Context<Self>,
2009 ) {
2010 let err = Err(anyhow::anyhow!(
2011 "Permission to run tool action denied by user"
2012 ));
2013
2014 self.tool_use
2015 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2016 self.tool_finished(tool_use_id.clone(), None, true, cx);
2017 }
2018}
2019
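/// Errors surfaced while running a thread, reported to the UI via
/// [`ThreadEvent::ShowError`].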
2020#[derive(Debug, Clone, Error)]
2021pub enum ThreadError {
2022 #[error("Payment required")]
2023 PaymentRequired,
2024 #[error("Max monthly spend reached")]
2025 MaxMonthlySpendReached,
2026 #[error("Model request limit reached")]
2027 ModelRequestLimitReached { plan: Plan },
2028 #[error("Message {header}: {message}")]
2029 Message {
2030 header: SharedString,
2031 message: SharedString,
2032 },
2033}
2034
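/// Events emitted by a [`Thread`] as it streams completions, mutates its messages,
/// and runs tools.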
2035#[derive(Debug, Clone)]
2036pub enum ThreadEvent {
2037 ShowError(ThreadError),
2038 StreamedCompletion,
2039 StreamedAssistantText(MessageId, String),
2040 StreamedAssistantThinking(MessageId, String),
2041 Stopped(Result<StopReason, Arc<anyhow::Error>>),
2042 MessageAdded(MessageId),
2043 MessageEdited(MessageId),
2044 MessageDeleted(MessageId),
2045 SummaryGenerated,
2046 SummaryChanged,
2047 UsePendingTools {
2048 tool_uses: Vec<PendingToolUse>,
2049 },
2050 ToolFinished {
2051 #[allow(unused)]
2052 tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool result, if any.
2054 pending_tool_use: Option<PendingToolUse>,
2055 },
2056 CheckpointChanged,
2057 ToolConfirmationNeeded,
2058}
2059
2060impl EventEmitter<ThreadEvent> for Thread {}
2061
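/// A completion request that is currently in flight, kept alive by its task.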
2062struct PendingCompletion {
2063 id: usize,
2064 _task: Task<()>,
2065}
2066
2067#[cfg(test)]
2068mod tests {
2069 use super::*;
2070 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2071 use assistant_settings::AssistantSettings;
2072 use context_server::ContextServerSettings;
2073 use editor::EditorSettings;
2074 use gpui::TestAppContext;
2075 use project::{FakeFs, Project};
2076 use prompt_store::PromptBuilder;
2077 use serde_json::json;
2078 use settings::{Settings, SettingsStore};
2079 use std::sync::Arc;
2080 use theme::ThemeSettings;
2081 use util::path;
2082 use workspace::Workspace;
2083
2084 #[gpui::test]
2085 async fn test_message_with_context(cx: &mut TestAppContext) {
2086 init_test_settings(cx);
2087
2088 let project = create_test_project(
2089 cx,
2090 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2091 )
2092 .await;
2093
2094 let (_workspace, _thread_store, thread, context_store) =
2095 setup_test_environment(cx, project.clone()).await;
2096
2097 add_file_to_context(&project, &context_store, "test/code.rs", cx)
2098 .await
2099 .unwrap();
2100
2101 let context =
2102 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2103
2104 // Insert user message with context
2105 let message_id = thread.update(cx, |thread, cx| {
2106 thread.insert_user_message("Please explain this code", vec![context], None, cx)
2107 });
2108
2109 // Check content and context in message object
2110 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2111
2112 // Use different path format strings based on platform for the test
2113 #[cfg(windows)]
2114 let path_part = r"test\code.rs";
2115 #[cfg(not(windows))]
2116 let path_part = "test/code.rs";
2117
2118 let expected_context = format!(
2119 r#"
2120<context>
2121The following items were attached by the user. You don't need to use other tools to read them.
2122
2123<files>
2124```rs {path_part}
2125fn main() {{
2126 println!("Hello, world!");
2127}}
2128```
2129</files>
2130</context>
2131"#
2132 );
2133
2134 assert_eq!(message.role, Role::User);
2135 assert_eq!(message.segments.len(), 1);
2136 assert_eq!(
2137 message.segments[0],
2138 MessageSegment::Text("Please explain this code".to_string())
2139 );
2140 assert_eq!(message.context, expected_context);
2141
2142 // Check message in request
2143 let request = thread.read_with(cx, |thread, cx| {
2144 thread.to_completion_request(RequestKind::Chat, cx)
2145 });
2146
2147 assert_eq!(request.messages.len(), 2);
2148 let expected_full_message = format!("{}Please explain this code", expected_context);
2149 assert_eq!(request.messages[1].string_contents(), expected_full_message);
2150 }
2151
2152 #[gpui::test]
2153 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2154 init_test_settings(cx);
2155
2156 let project = create_test_project(
2157 cx,
2158 json!({
2159 "file1.rs": "fn function1() {}\n",
2160 "file2.rs": "fn function2() {}\n",
2161 "file3.rs": "fn function3() {}\n",
2162 }),
2163 )
2164 .await;
2165
2166 let (_, _thread_store, thread, context_store) =
2167 setup_test_environment(cx, project.clone()).await;
2168
2169 // Open files individually
2170 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2171 .await
2172 .unwrap();
2173 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2174 .await
2175 .unwrap();
2176 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2177 .await
2178 .unwrap();
2179
2180 // Get the context objects
2181 let contexts = context_store.update(cx, |store, _| store.context().clone());
2182 assert_eq!(contexts.len(), 3);
2183
2184 // First message with context 1
2185 let message1_id = thread.update(cx, |thread, cx| {
2186 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2187 });
2188
2189 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2190 let message2_id = thread.update(cx, |thread, cx| {
2191 thread.insert_user_message(
2192 "Message 2",
2193 vec![contexts[0].clone(), contexts[1].clone()],
2194 None,
2195 cx,
2196 )
2197 });
2198
2199 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2200 let message3_id = thread.update(cx, |thread, cx| {
2201 thread.insert_user_message(
2202 "Message 3",
2203 vec![
2204 contexts[0].clone(),
2205 contexts[1].clone(),
2206 contexts[2].clone(),
2207 ],
2208 None,
2209 cx,
2210 )
2211 });
2212
2213 // Check what contexts are included in each message
2214 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2215 (
2216 thread.message(message1_id).unwrap().clone(),
2217 thread.message(message2_id).unwrap().clone(),
2218 thread.message(message3_id).unwrap().clone(),
2219 )
2220 });
2221
2222 // First message should include context 1
2223 assert!(message1.context.contains("file1.rs"));
2224
2225 // Second message should include only context 2 (not 1)
2226 assert!(!message2.context.contains("file1.rs"));
2227 assert!(message2.context.contains("file2.rs"));
2228
2229 // Third message should include only context 3 (not 1 or 2)
2230 assert!(!message3.context.contains("file1.rs"));
2231 assert!(!message3.context.contains("file2.rs"));
2232 assert!(message3.context.contains("file3.rs"));
2233
2234 // Check entire request to make sure all contexts are properly included
2235 let request = thread.read_with(cx, |thread, cx| {
2236 thread.to_completion_request(RequestKind::Chat, cx)
2237 });
2238
        // The request should contain all 3 user messages (plus the initial message at index 0)
2240 assert_eq!(request.messages.len(), 4);
2241
2242 // Check that the contexts are properly formatted in each message
2243 assert!(request.messages[1].string_contents().contains("file1.rs"));
2244 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2245 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2246
2247 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2248 assert!(request.messages[2].string_contents().contains("file2.rs"));
2249 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2250
2251 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2252 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2253 assert!(request.messages[3].string_contents().contains("file3.rs"));
2254 }
2255
2256 #[gpui::test]
2257 async fn test_message_without_files(cx: &mut TestAppContext) {
2258 init_test_settings(cx);
2259
2260 let project = create_test_project(
2261 cx,
2262 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2263 )
2264 .await;
2265
2266 let (_, _thread_store, thread, _context_store) =
2267 setup_test_environment(cx, project.clone()).await;
2268
2269 // Insert user message without any context (empty context vector)
2270 let message_id = thread.update(cx, |thread, cx| {
2271 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2272 });
2273
2274 // Check content and context in message object
2275 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2276
2277 // Context should be empty when no files are included
2278 assert_eq!(message.role, Role::User);
2279 assert_eq!(message.segments.len(), 1);
2280 assert_eq!(
2281 message.segments[0],
2282 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2283 );
2284 assert_eq!(message.context, "");
2285
2286 // Check message in request
2287 let request = thread.read_with(cx, |thread, cx| {
2288 thread.to_completion_request(RequestKind::Chat, cx)
2289 });
2290
2291 assert_eq!(request.messages.len(), 2);
2292 assert_eq!(
2293 request.messages[1].string_contents(),
2294 "What is the best way to learn Rust?"
2295 );
2296
2297 // Add second message, also without context
2298 let message2_id = thread.update(cx, |thread, cx| {
2299 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2300 });
2301
2302 let message2 =
2303 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2304 assert_eq!(message2.context, "");
2305
2306 // Check that both messages appear in the request
2307 let request = thread.read_with(cx, |thread, cx| {
2308 thread.to_completion_request(RequestKind::Chat, cx)
2309 });
2310
2311 assert_eq!(request.messages.len(), 3);
2312 assert_eq!(
2313 request.messages[1].string_contents(),
2314 "What is the best way to learn Rust?"
2315 );
2316 assert_eq!(
2317 request.messages[2].string_contents(),
2318 "Are there any good books?"
2319 );
2320 }
2321
2322 #[gpui::test]
2323 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2324 init_test_settings(cx);
2325
2326 let project = create_test_project(
2327 cx,
2328 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2329 )
2330 .await;
2331
2332 let (_workspace, _thread_store, thread, context_store) =
2333 setup_test_environment(cx, project.clone()).await;
2334
2335 // Open buffer and add it to context
2336 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2337 .await
2338 .unwrap();
2339
2340 let context =
2341 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2342
2343 // Insert user message with the buffer as context
2344 thread.update(cx, |thread, cx| {
2345 thread.insert_user_message("Explain this code", vec![context], None, cx)
2346 });
2347
2348 // Create a request and check that it doesn't have a stale buffer warning yet
2349 let initial_request = thread.read_with(cx, |thread, cx| {
2350 thread.to_completion_request(RequestKind::Chat, cx)
2351 });
2352
2353 // Make sure we don't have a stale file warning yet
2354 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2355 msg.string_contents()
2356 .contains("These files changed since last read:")
2357 });
2358 assert!(
2359 !has_stale_warning,
2360 "Should not have stale buffer warning before buffer is modified"
2361 );
2362
2363 // Modify the buffer
2364 buffer.update(cx, |buffer, cx| {
            // Insert new text near the start of the buffer so the attached context becomes stale
2366 buffer.edit(
2367 [(1..1, "\n println!(\"Added a new line\");\n")],
2368 None,
2369 cx,
2370 );
2371 });
2372
2373 // Insert another user message without context
2374 thread.update(cx, |thread, cx| {
2375 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2376 });
2377
2378 // Create a new request and check for the stale buffer warning
2379 let new_request = thread.read_with(cx, |thread, cx| {
2380 thread.to_completion_request(RequestKind::Chat, cx)
2381 });
2382
2383 // We should have a stale file warning as the last message
2384 let last_message = new_request
2385 .messages
2386 .last()
2387 .expect("Request should have messages");
2388
2389 // The last message should be the stale buffer notification
2390 assert_eq!(last_message.role, Role::User);
2391
2392 // Check the exact content of the message
2393 let expected_content = "These files changed since last read:\n- code.rs\n";
2394 assert_eq!(
2395 last_message.string_contents(),
2396 expected_content,
2397 "Last message should be exactly the stale buffer notification"
2398 );
2399 }
2400
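    // Helper to register the settings and globals the thread tests depend on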
2401 fn init_test_settings(cx: &mut TestAppContext) {
2402 cx.update(|cx| {
2403 let settings_store = SettingsStore::test(cx);
2404 cx.set_global(settings_store);
2405 language::init(cx);
2406 Project::init_settings(cx);
2407 AssistantSettings::register(cx);
2408 thread_store::init(cx);
2409 workspace::init_settings(cx);
2410 ThemeSettings::register(cx);
2411 ContextServerSettings::register(cx);
2412 EditorSettings::register(cx);
2413 });
2414 }
2415
2416 // Helper to create a test project with test files
2417 async fn create_test_project(
2418 cx: &mut TestAppContext,
2419 files: serde_json::Value,
2420 ) -> Entity<Project> {
2421 let fs = FakeFs::new(cx.executor());
2422 fs.insert_tree(path!("/test"), files).await;
2423 Project::test(fs, [path!("/test").as_ref()], cx).await
2424 }
2425
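    // Helper to create a workspace, thread store, thread, and context store for tests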
2426 async fn setup_test_environment(
2427 cx: &mut TestAppContext,
2428 project: Entity<Project>,
2429 ) -> (
2430 Entity<Workspace>,
2431 Entity<ThreadStore>,
2432 Entity<Thread>,
2433 Entity<ContextStore>,
2434 ) {
2435 let (workspace, cx) =
2436 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2437
2438 let thread_store = cx
2439 .update(|_, cx| {
2440 ThreadStore::load(
2441 project.clone(),
2442 cx.new(|_| ToolWorkingSet::default()),
2443 Arc::new(PromptBuilder::new(None).unwrap()),
2444 cx,
2445 )
2446 })
2447 .await;
2448
2449 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2450 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2451
2452 (workspace, thread_store, thread, context_store)
2453 }
2454
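    // Helper to open a buffer from the test project and add it to the context store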
2455 async fn add_file_to_context(
2456 project: &Entity<Project>,
2457 context_store: &Entity<ContextStore>,
2458 path: &str,
2459 cx: &mut TestAppContext,
2460 ) -> Result<Entity<language::Buffer>> {
2461 let buffer_path = project
2462 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2463 .unwrap();
2464
2465 let buffer = project
2466 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2467 .await
2468 .unwrap();
2469
2470 context_store
2471 .update(cx, |store, cx| {
2472 store.add_file_from_buffer(buffer.clone(), cx)
2473 })
2474 .await?;
2475
2476 Ok(buffer)
2477 }
2478}