1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
42
43#[derive(Debug, Clone, Copy)]
44pub enum RequestKind {
45 Chat,
46 /// Used when summarizing a thread.
47 Summarize,
48}
49
50#[derive(
51 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
52)]
53pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
73#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
74pub struct MessageId(pub(crate) usize);
75
76impl MessageId {
77 fn post_inc(&mut self) -> Self {
78 Self(post_inc(&mut self.0))
79 }
80}
81
82/// A message in a [`Thread`].
83#[derive(Debug, Clone)]
84pub struct Message {
85 pub id: MessageId,
86 pub role: Role,
87 pub segments: Vec<MessageSegment>,
88 pub context: String,
89}
90
91impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    ///
    /// The model sometimes runs tools without producing any text, or produces only the
    /// [`USING_TOOL_MARKER`] placeholder; such messages should not be shown.
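    ///
    /// A minimal sketch of the intended behavior (not a doctest, since [`MessageId`]
    /// can only be constructed inside this crate):
    ///
    /// ```ignore
    /// let mut message = Message {
    ///     id: MessageId(0),
    ///     role: Role::Assistant,
    ///     segments: vec![MessageSegment::Text(String::new())],
    ///     context: String::new(),
    /// };
    /// assert!(!message.should_display_content()); // only an empty segment
    /// message.push_text("Here is the fix.");
    /// assert!(message.should_display_content());
    /// ```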
94 pub fn should_display_content(&self) -> bool {
95 self.segments.iter().all(|segment| segment.should_display())
96 }
97
98 pub fn push_thinking(&mut self, text: &str) {
99 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
100 segment.push_str(text);
101 } else {
102 self.segments
103 .push(MessageSegment::Thinking(text.to_string()));
104 }
105 }
106
107 pub fn push_text(&mut self, text: &str) {
108 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
109 segment.push_str(text);
110 } else {
111 self.segments.push(MessageSegment::Text(text.to_string()));
112 }
113 }
114
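    /// Renders the message as plain text: any attached context first, then each
    /// segment in order, with thinking segments wrapped in `<think>` tags.
    ///
    /// A minimal sketch (not a doctest, since [`MessageId`] can only be constructed
    /// inside this crate):
    ///
    /// ```ignore
    /// let mut message = Message {
    ///     id: MessageId(0),
    ///     role: Role::Assistant,
    ///     segments: Vec::new(),
    ///     context: String::new(),
    /// };
    /// message.push_thinking("weighing options");
    /// message.push_text("Use a BTreeMap here.");
    /// assert_eq!(
    ///     message.to_string(),
    ///     "<think>weighing options</think>Use a BTreeMap here."
    /// );
    /// ```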
115 pub fn to_string(&self) -> String {
116 let mut result = String::new();
117
118 if !self.context.is_empty() {
119 result.push_str(&self.context);
120 }
121
122 for segment in &self.segments {
123 match segment {
124 MessageSegment::Text(text) => result.push_str(text),
125 MessageSegment::Thinking(text) => {
126 result.push_str("<think>");
127 result.push_str(text);
128 result.push_str("</think>");
129 }
130 }
131 }
132
133 result
134 }
135}
136
137#[derive(Debug, Clone, PartialEq, Eq)]
138pub enum MessageSegment {
139 Text(String),
140 Thinking(String),
141}
142
143impl MessageSegment {
144 pub fn text_mut(&mut self) -> &mut String {
145 match self {
146 Self::Text(text) => text,
147 Self::Thinking(text) => text,
148 }
149 }
150
151 pub fn should_display(&self) -> bool {
152 // We add USING_TOOL_MARKER when making a request that includes tool uses
153 // without non-whitespace text around them, and this can cause the model
154 // to mimic the pattern, so we consider those segments not displayable.
155 match self {
            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::Thinking(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
158 }
159 }
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct ProjectSnapshot {
164 pub worktree_snapshots: Vec<WorktreeSnapshot>,
165 pub unsaved_buffer_paths: Vec<String>,
166 pub timestamp: DateTime<Utc>,
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct WorktreeSnapshot {
171 pub worktree_path: String,
172 pub git_state: Option<GitState>,
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct GitState {
177 pub remote_url: Option<String>,
178 pub head_sha: Option<String>,
179 pub current_branch: Option<String>,
180 pub diff: Option<String>,
181}
182
183#[derive(Clone)]
184pub struct ThreadCheckpoint {
185 message_id: MessageId,
186 git_checkpoint: GitStoreCheckpoint,
187}
188
189#[derive(Copy, Clone, Debug, PartialEq, Eq)]
190pub enum ThreadFeedback {
191 Positive,
192 Negative,
193}
194
195pub enum LastRestoreCheckpoint {
196 Pending {
197 message_id: MessageId,
198 },
199 Error {
200 message_id: MessageId,
201 error: String,
202 },
203}
204
205impl LastRestoreCheckpoint {
206 pub fn message_id(&self) -> MessageId {
207 match self {
208 LastRestoreCheckpoint::Pending { message_id } => *message_id,
209 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
210 }
211 }
212}
213
214#[derive(Clone, Debug, Default, Serialize, Deserialize)]
215pub enum DetailedSummaryState {
216 #[default]
217 NotGenerated,
218 Generating {
219 message_id: MessageId,
220 },
221 Generated {
222 text: SharedString,
223 message_id: MessageId,
224 },
225}
226
227#[derive(Default)]
228pub struct TotalTokenUsage {
229 pub total: usize,
230 pub max: usize,
231}
232
233impl TotalTokenUsage {
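    /// Classifies the current usage relative to the model's context window.
    ///
    /// A sketch of the expected classification, assuming the default 0.8 warning
    /// threshold:
    ///
    /// ```ignore
    /// let usage = TotalTokenUsage { total: 850, max: 1000 };
    /// assert_eq!(usage.ratio(), TokenUsageRatio::Warning);
    /// assert_eq!(usage.add(150).ratio(), TokenUsageRatio::Exceeded);
    /// assert_eq!(
    ///     TotalTokenUsage { total: 100, max: 1000 }.ratio(),
    ///     TokenUsageRatio::Normal
    /// );
    /// ```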
234 pub fn ratio(&self) -> TokenUsageRatio {
235 #[cfg(debug_assertions)]
236 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
237 .unwrap_or("0.8".to_string())
238 .parse()
239 .unwrap();
240 #[cfg(not(debug_assertions))]
241 let warning_threshold: f32 = 0.8;
242
243 if self.total >= self.max {
244 TokenUsageRatio::Exceeded
245 } else if self.total as f32 / self.max as f32 >= warning_threshold {
246 TokenUsageRatio::Warning
247 } else {
248 TokenUsageRatio::Normal
249 }
250 }
251
252 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
253 TotalTokenUsage {
254 total: self.total + tokens,
255 max: self.max,
256 }
257 }
258}
259
260#[derive(Debug, Default, PartialEq, Eq)]
261pub enum TokenUsageRatio {
262 #[default]
263 Normal,
264 Warning,
265 Exceeded,
266}
267
268/// A thread of conversation with the LLM.
269pub struct Thread {
270 id: ThreadId,
271 updated_at: DateTime<Utc>,
272 summary: Option<SharedString>,
273 pending_summary: Task<Option<()>>,
274 detailed_summary_state: DetailedSummaryState,
275 messages: Vec<Message>,
276 next_message_id: MessageId,
277 context: BTreeMap<ContextId, AssistantContext>,
278 context_by_message: HashMap<MessageId, Vec<ContextId>>,
279 project_context: SharedProjectContext,
280 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
281 completion_count: usize,
282 pending_completions: Vec<PendingCompletion>,
283 project: Entity<Project>,
284 prompt_builder: Arc<PromptBuilder>,
285 tools: Entity<ToolWorkingSet>,
286 tool_use: ToolUseState,
287 action_log: Entity<ActionLog>,
288 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
289 pending_checkpoint: Option<ThreadCheckpoint>,
290 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
291 request_token_usage: Vec<TokenUsage>,
292 cumulative_token_usage: TokenUsage,
293 exceeded_window_error: Option<ExceededWindowError>,
294 feedback: Option<ThreadFeedback>,
295 message_feedback: HashMap<MessageId, ThreadFeedback>,
296 last_auto_capture_at: Option<Instant>,
297}
298
299#[derive(Debug, Clone, Serialize, Deserialize)]
300pub struct ExceededWindowError {
    /// The model that was in use when the last message exceeded the context window.
    model_id: LanguageModelId,
    /// The token count, including the last message.
    token_count: usize,
305}
306
307impl Thread {
308 pub fn new(
309 project: Entity<Project>,
310 tools: Entity<ToolWorkingSet>,
311 prompt_builder: Arc<PromptBuilder>,
312 system_prompt: SharedProjectContext,
313 cx: &mut Context<Self>,
314 ) -> Self {
315 Self {
316 id: ThreadId::new(),
317 updated_at: Utc::now(),
318 summary: None,
319 pending_summary: Task::ready(None),
320 detailed_summary_state: DetailedSummaryState::NotGenerated,
321 messages: Vec::new(),
322 next_message_id: MessageId(0),
323 context: BTreeMap::default(),
324 context_by_message: HashMap::default(),
325 project_context: system_prompt,
326 checkpoints_by_message: HashMap::default(),
327 completion_count: 0,
328 pending_completions: Vec::new(),
329 project: project.clone(),
330 prompt_builder,
331 tools: tools.clone(),
332 last_restore_checkpoint: None,
333 pending_checkpoint: None,
334 tool_use: ToolUseState::new(tools.clone()),
335 action_log: cx.new(|_| ActionLog::new(project.clone())),
336 initial_project_snapshot: {
337 let project_snapshot = Self::project_snapshot(project, cx);
338 cx.foreground_executor()
339 .spawn(async move { Some(project_snapshot.await) })
340 .shared()
341 },
342 request_token_usage: Vec::new(),
343 cumulative_token_usage: TokenUsage::default(),
344 exceeded_window_error: None,
345 feedback: None,
346 message_feedback: HashMap::default(),
347 last_auto_capture_at: None,
348 }
349 }
350
351 pub fn deserialize(
352 id: ThreadId,
353 serialized: SerializedThread,
354 project: Entity<Project>,
355 tools: Entity<ToolWorkingSet>,
356 prompt_builder: Arc<PromptBuilder>,
357 project_context: SharedProjectContext,
358 cx: &mut Context<Self>,
359 ) -> Self {
360 let next_message_id = MessageId(
361 serialized
362 .messages
363 .last()
364 .map(|message| message.id.0 + 1)
365 .unwrap_or(0),
366 );
367 let tool_use =
368 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
369
370 Self {
371 id,
372 updated_at: serialized.updated_at,
373 summary: Some(serialized.summary),
374 pending_summary: Task::ready(None),
375 detailed_summary_state: serialized.detailed_summary_state,
376 messages: serialized
377 .messages
378 .into_iter()
379 .map(|message| Message {
380 id: message.id,
381 role: message.role,
382 segments: message
383 .segments
384 .into_iter()
385 .map(|segment| match segment {
386 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
387 SerializedMessageSegment::Thinking { text } => {
388 MessageSegment::Thinking(text)
389 }
390 })
391 .collect(),
392 context: message.context,
393 })
394 .collect(),
395 next_message_id,
396 context: BTreeMap::default(),
397 context_by_message: HashMap::default(),
398 project_context,
399 checkpoints_by_message: HashMap::default(),
400 completion_count: 0,
401 pending_completions: Vec::new(),
402 last_restore_checkpoint: None,
403 pending_checkpoint: None,
404 project: project.clone(),
405 prompt_builder,
406 tools,
407 tool_use,
408 action_log: cx.new(|_| ActionLog::new(project)),
409 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
410 request_token_usage: serialized.request_token_usage,
411 cumulative_token_usage: serialized.cumulative_token_usage,
412 exceeded_window_error: None,
413 feedback: None,
414 message_feedback: HashMap::default(),
415 last_auto_capture_at: None,
416 }
417 }
418
419 pub fn id(&self) -> &ThreadId {
420 &self.id
421 }
422
423 pub fn is_empty(&self) -> bool {
424 self.messages.is_empty()
425 }
426
427 pub fn updated_at(&self) -> DateTime<Utc> {
428 self.updated_at
429 }
430
431 pub fn touch_updated_at(&mut self) {
432 self.updated_at = Utc::now();
433 }
434
435 pub fn summary(&self) -> Option<SharedString> {
436 self.summary.clone()
437 }
438
439 pub fn project_context(&self) -> SharedProjectContext {
440 self.project_context.clone()
441 }
442
443 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
444
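    /// Falls back to [`Self::DEFAULT_SUMMARY`] ("New Thread") while no summary has
    /// been generated yet.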
445 pub fn summary_or_default(&self) -> SharedString {
446 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
447 }
448
449 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
450 let Some(current_summary) = &self.summary else {
451 // Don't allow setting summary until generated
452 return;
453 };
454
455 let mut new_summary = new_summary.into();
456
457 if new_summary.is_empty() {
458 new_summary = Self::DEFAULT_SUMMARY;
459 }
460
461 if current_summary != &new_summary {
462 self.summary = Some(new_summary);
463 cx.emit(ThreadEvent::SummaryChanged);
464 }
465 }
466
467 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
468 self.latest_detailed_summary()
469 .unwrap_or_else(|| self.text().into())
470 }
471
472 fn latest_detailed_summary(&self) -> Option<SharedString> {
473 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
474 Some(text.clone())
475 } else {
476 None
477 }
478 }
479
480 pub fn message(&self, id: MessageId) -> Option<&Message> {
481 self.messages.iter().find(|message| message.id == id)
482 }
483
484 pub fn messages(&self) -> impl Iterator<Item = &Message> {
485 self.messages.iter()
486 }
487
488 pub fn is_generating(&self) -> bool {
489 !self.pending_completions.is_empty() || !self.all_tools_finished()
490 }
491
492 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
493 &self.tools
494 }
495
496 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
497 self.tool_use
498 .pending_tool_uses()
499 .into_iter()
500 .find(|tool_use| &tool_use.id == id)
501 }
502
503 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
504 self.tool_use
505 .pending_tool_uses()
506 .into_iter()
507 .filter(|tool_use| tool_use.status.needs_confirmation())
508 }
509
510 pub fn has_pending_tool_uses(&self) -> bool {
511 !self.tool_use.pending_tool_uses().is_empty()
512 }
513
514 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
515 self.checkpoints_by_message.get(&id).cloned()
516 }
517
518 pub fn restore_checkpoint(
519 &mut self,
520 checkpoint: ThreadCheckpoint,
521 cx: &mut Context<Self>,
522 ) -> Task<Result<()>> {
523 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
524 message_id: checkpoint.message_id,
525 });
526 cx.emit(ThreadEvent::CheckpointChanged);
527 cx.notify();
528
529 let git_store = self.project().read(cx).git_store().clone();
530 let restore = git_store.update(cx, |git_store, cx| {
531 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
532 });
533
534 cx.spawn(async move |this, cx| {
535 let result = restore.await;
536 this.update(cx, |this, cx| {
537 if let Err(err) = result.as_ref() {
538 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
539 message_id: checkpoint.message_id,
540 error: err.to_string(),
541 });
542 } else {
543 this.truncate(checkpoint.message_id, cx);
544 this.last_restore_checkpoint = None;
545 }
546 this.pending_checkpoint = None;
547 cx.emit(ThreadEvent::CheckpointChanged);
548 cx.notify();
549 })?;
550 result
551 })
552 }
553
554 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
555 let pending_checkpoint = if self.is_generating() {
556 return;
557 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
558 checkpoint
559 } else {
560 return;
561 };
562
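        // If the repository is unchanged since the pending checkpoint was taken, the
        // checkpoint is deleted; otherwise it is attached to its message below.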
563 let git_store = self.project.read(cx).git_store().clone();
564 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
565 cx.spawn(async move |this, cx| match final_checkpoint.await {
566 Ok(final_checkpoint) => {
567 let equal = git_store
568 .update(cx, |store, cx| {
569 store.compare_checkpoints(
570 pending_checkpoint.git_checkpoint.clone(),
571 final_checkpoint.clone(),
572 cx,
573 )
574 })?
575 .await
576 .unwrap_or(false);
577
578 if equal {
579 git_store
580 .update(cx, |store, cx| {
581 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
582 })?
583 .detach();
584 } else {
585 this.update(cx, |this, cx| {
586 this.insert_checkpoint(pending_checkpoint, cx)
587 })?;
588 }
589
590 git_store
591 .update(cx, |store, cx| {
592 store.delete_checkpoint(final_checkpoint, cx)
593 })?
594 .detach();
595
596 Ok(())
597 }
598 Err(_) => this.update(cx, |this, cx| {
599 this.insert_checkpoint(pending_checkpoint, cx)
600 }),
601 })
602 .detach();
603 }
604
605 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
606 self.checkpoints_by_message
607 .insert(checkpoint.message_id, checkpoint);
608 cx.emit(ThreadEvent::CheckpointChanged);
609 cx.notify();
610 }
611
612 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
613 self.last_restore_checkpoint.as_ref()
614 }
615
616 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
617 let Some(message_ix) = self
618 .messages
619 .iter()
620 .rposition(|message| message.id == message_id)
621 else {
622 return;
623 };
624 for deleted_message in self.messages.drain(message_ix..) {
625 self.context_by_message.remove(&deleted_message.id);
626 self.checkpoints_by_message.remove(&deleted_message.id);
627 }
628 cx.notify();
629 }
630
631 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
632 self.context_by_message
633 .get(&id)
634 .into_iter()
635 .flat_map(|context| {
636 context
637 .iter()
638 .filter_map(|context_id| self.context.get(&context_id))
639 })
640 }
641
642 /// Returns whether all of the tool uses have finished running.
643 pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // we've finished running all of the pending tools.
646 self.tool_use
647 .pending_tool_uses()
648 .iter()
649 .all(|tool_use| tool_use.status.is_error())
650 }
651
652 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
653 self.tool_use.tool_uses_for_message(id, cx)
654 }
655
656 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
657 self.tool_use.tool_results_for_message(id)
658 }
659
660 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
661 self.tool_use.tool_result(id)
662 }
663
664 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
665 Some(&self.tool_use.tool_result(id)?.content)
666 }
667
668 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
669 self.tool_use.tool_result_card(id).cloned()
670 }
671
672 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
673 self.tool_use.message_has_tool_results(message_id)
674 }
675
676 /// Filter out contexts that have already been included in previous messages
677 pub fn filter_new_context<'a>(
678 &self,
679 context: impl Iterator<Item = &'a AssistantContext>,
680 ) -> impl Iterator<Item = &'a AssistantContext> {
681 context.filter(|ctx| self.is_context_new(ctx))
682 }
683
684 fn is_context_new(&self, context: &AssistantContext) -> bool {
685 !self.context.contains_key(&context.id())
686 }
687
688 pub fn insert_user_message(
689 &mut self,
690 text: impl Into<String>,
691 context: Vec<AssistantContext>,
692 git_checkpoint: Option<GitStoreCheckpoint>,
693 cx: &mut Context<Self>,
694 ) -> MessageId {
695 let text = text.into();
696
697 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
698
699 let new_context: Vec<_> = context
700 .into_iter()
701 .filter(|ctx| self.is_context_new(ctx))
702 .collect();
703
704 if !new_context.is_empty() {
705 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
706 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
707 message.context = context_string;
708 }
709 }
710
711 self.action_log.update(cx, |log, cx| {
712 // Track all buffers added as context
713 for ctx in &new_context {
714 match ctx {
715 AssistantContext::File(file_ctx) => {
716 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
717 }
718 AssistantContext::Directory(dir_ctx) => {
719 for context_buffer in &dir_ctx.context_buffers {
720 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
721 }
722 }
723 AssistantContext::Symbol(symbol_ctx) => {
724 log.buffer_added_as_context(
725 symbol_ctx.context_symbol.buffer.clone(),
726 cx,
727 );
728 }
729 AssistantContext::Excerpt(excerpt_context) => {
730 log.buffer_added_as_context(
731 excerpt_context.context_buffer.buffer.clone(),
732 cx,
733 );
734 }
735 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
736 }
737 }
738 });
739 }
740
741 let context_ids = new_context
742 .iter()
743 .map(|context| context.id())
744 .collect::<Vec<_>>();
745 self.context.extend(
746 new_context
747 .into_iter()
748 .map(|context| (context.id(), context)),
749 );
750 self.context_by_message.insert(message_id, context_ids);
751
752 if let Some(git_checkpoint) = git_checkpoint {
753 self.pending_checkpoint = Some(ThreadCheckpoint {
754 message_id,
755 git_checkpoint,
756 });
757 }
758
759 self.auto_capture_telemetry(cx);
760
761 message_id
762 }
763
764 pub fn insert_message(
765 &mut self,
766 role: Role,
767 segments: Vec<MessageSegment>,
768 cx: &mut Context<Self>,
769 ) -> MessageId {
770 let id = self.next_message_id.post_inc();
771 self.messages.push(Message {
772 id,
773 role,
774 segments,
775 context: String::new(),
776 });
777 self.touch_updated_at();
778 cx.emit(ThreadEvent::MessageAdded(id));
779 id
780 }
781
782 pub fn edit_message(
783 &mut self,
784 id: MessageId,
785 new_role: Role,
786 new_segments: Vec<MessageSegment>,
787 cx: &mut Context<Self>,
788 ) -> bool {
789 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
790 return false;
791 };
792 message.role = new_role;
793 message.segments = new_segments;
794 self.touch_updated_at();
795 cx.emit(ThreadEvent::MessageEdited(id));
796 true
797 }
798
799 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
800 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
801 return false;
802 };
803 self.messages.remove(index);
804 self.context_by_message.remove(&id);
805 self.touch_updated_at();
806 cx.emit(ThreadEvent::MessageDeleted(id));
807 true
808 }
809
810 /// Returns the representation of this [`Thread`] in a textual form.
811 ///
812 /// This is the representation we use when attaching a thread as context to another thread.
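    ///
    /// The output is shaped roughly like this (illustrative, not an exact transcript):
    ///
    /// ```text
    /// User:
    /// How should we cache these requests?
    /// Assistant:
    /// <think>considering trade-offs</think>Use the existing checkpoint store.
    /// ```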
813 pub fn text(&self) -> String {
814 let mut text = String::new();
815
816 for message in &self.messages {
817 text.push_str(match message.role {
818 language_model::Role::User => "User:",
819 language_model::Role::Assistant => "Assistant:",
820 language_model::Role::System => "System:",
821 });
822 text.push('\n');
823
824 for segment in &message.segments {
825 match segment {
826 MessageSegment::Text(content) => text.push_str(content),
827 MessageSegment::Thinking(content) => {
828 text.push_str(&format!("<think>{}</think>", content))
829 }
830 }
831 }
832 text.push('\n');
833 }
834
835 text
836 }
837
838 /// Serializes this thread into a format for storage or telemetry.
839 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
840 let initial_project_snapshot = self.initial_project_snapshot.clone();
841 cx.spawn(async move |this, cx| {
842 let initial_project_snapshot = initial_project_snapshot.await;
843 this.read_with(cx, |this, cx| SerializedThread {
844 version: SerializedThread::VERSION.to_string(),
845 summary: this.summary_or_default(),
846 updated_at: this.updated_at(),
847 messages: this
848 .messages()
849 .map(|message| SerializedMessage {
850 id: message.id,
851 role: message.role,
852 segments: message
853 .segments
854 .iter()
855 .map(|segment| match segment {
856 MessageSegment::Text(text) => {
857 SerializedMessageSegment::Text { text: text.clone() }
858 }
859 MessageSegment::Thinking(text) => {
860 SerializedMessageSegment::Thinking { text: text.clone() }
861 }
862 })
863 .collect(),
864 tool_uses: this
865 .tool_uses_for_message(message.id, cx)
866 .into_iter()
867 .map(|tool_use| SerializedToolUse {
868 id: tool_use.id,
869 name: tool_use.name,
870 input: tool_use.input,
871 })
872 .collect(),
873 tool_results: this
874 .tool_results_for_message(message.id)
875 .into_iter()
876 .map(|tool_result| SerializedToolResult {
877 tool_use_id: tool_result.tool_use_id.clone(),
878 is_error: tool_result.is_error,
879 content: tool_result.content.clone(),
880 })
881 .collect(),
882 context: message.context.clone(),
883 })
884 .collect(),
885 initial_project_snapshot,
886 cumulative_token_usage: this.cumulative_token_usage,
887 request_token_usage: this.request_token_usage.clone(),
888 detailed_summary_state: this.detailed_summary_state.clone(),
889 exceeded_window_error: this.exceeded_window_error.clone(),
890 })
891 })
892 }
893
894 pub fn send_to_model(
895 &mut self,
896 model: Arc<dyn LanguageModel>,
897 request_kind: RequestKind,
898 cx: &mut Context<Self>,
899 ) {
900 let mut request = self.to_completion_request(request_kind, cx);
901 if model.supports_tools() {
902 request.tools = {
903 let mut tools = Vec::new();
904 tools.extend(
905 self.tools()
906 .read(cx)
907 .enabled_tools(cx)
908 .into_iter()
909 .filter_map(|tool| {
910 // Skip tools that cannot be supported
911 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
912 Some(LanguageModelRequestTool {
913 name: tool.name(),
914 description: tool.description(),
915 input_schema,
916 })
917 }),
918 );
919
920 tools
921 };
922 }
923
924 self.stream_completion(request, model, cx);
925 }
926
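    /// Reports whether any tool results appear after the most recent user message,
    /// walking the messages from newest to oldest.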
927 pub fn used_tools_since_last_user_message(&self) -> bool {
928 for message in self.messages.iter().rev() {
929 if self.tool_use.message_has_tool_results(message.id) {
930 return true;
931 } else if message.role == Role::User {
932 return false;
933 }
934 }
935
936 false
937 }
938
939 pub fn to_completion_request(
940 &self,
941 request_kind: RequestKind,
942 cx: &App,
943 ) -> LanguageModelRequest {
944 let mut request = LanguageModelRequest {
945 messages: vec![],
946 tools: Vec::new(),
947 stop: Vec::new(),
948 temperature: None,
949 };
950
951 if let Some(project_context) = self.project_context.borrow().as_ref() {
952 if let Some(system_prompt) = self
953 .prompt_builder
954 .generate_assistant_system_prompt(project_context)
955 .context("failed to generate assistant system prompt")
956 .log_err()
957 {
958 request.messages.push(LanguageModelRequestMessage {
959 role: Role::System,
960 content: vec![MessageContent::Text(system_prompt)],
961 cache: true,
962 });
963 }
964 } else {
965 log::error!("project_context not set.")
966 }
967
968 for message in &self.messages {
969 let mut request_message = LanguageModelRequestMessage {
970 role: message.role,
971 content: Vec::new(),
972 cache: false,
973 };
974
975 match request_kind {
976 RequestKind::Chat => {
977 self.tool_use
978 .attach_tool_results(message.id, &mut request_message);
979 }
980 RequestKind::Summarize => {
981 // We don't care about tool use during summarization.
982 if self.tool_use.message_has_tool_results(message.id) {
983 continue;
984 }
985 }
986 }
987
988 if !message.segments.is_empty() {
989 request_message
990 .content
991 .push(MessageContent::Text(message.to_string()));
992 }
993
994 match request_kind {
995 RequestKind::Chat => {
996 self.tool_use
997 .attach_tool_uses(message.id, &mut request_message);
998 }
999 RequestKind::Summarize => {
1000 // We don't care about tool use during summarization.
1001 }
1002 };
1003
1004 request.messages.push(request_message);
1005 }
1006
1007 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1008 if let Some(last) = request.messages.last_mut() {
1009 last.cache = true;
1010 }
1011
1012 self.attached_tracked_files_state(&mut request.messages, cx);
1013
1014 request
1015 }
1016
1017 fn attached_tracked_files_state(
1018 &self,
1019 messages: &mut Vec<LanguageModelRequestMessage>,
1020 cx: &App,
1021 ) {
1022 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1023
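        // The injected message body ends up looking roughly like this
        // (paths are illustrative):
        //
        //   These files changed since last read:
        //   - src/thread.rs
        //   - src/tool_use.rs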
1024 let mut stale_message = String::new();
1025
1026 let action_log = self.action_log.read(cx);
1027
1028 for stale_file in action_log.stale_buffers(cx) {
1029 let Some(file) = stale_file.read(cx).file() else {
1030 continue;
1031 };
1032
1033 if stale_message.is_empty() {
            writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
1035 }
1036
1037 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1038 }
1039
1040 let mut content = Vec::with_capacity(2);
1041
1042 if !stale_message.is_empty() {
1043 content.push(stale_message.into());
1044 }
1045
1046 if action_log.has_edited_files_since_project_diagnostics_check() {
1047 content.push(
1048 "\n\nWhen you're done making changes, make sure to check project diagnostics \
1049 and fix all errors AND warnings you introduced! \
1050 DO NOT mention you're going to do this until you're done."
1051 .into(),
1052 );
1053 }
1054
1055 if !content.is_empty() {
1056 let context_message = LanguageModelRequestMessage {
1057 role: Role::User,
1058 content,
1059 cache: false,
1060 };
1061
1062 messages.push(context_message);
1063 }
1064 }
1065
1066 pub fn stream_completion(
1067 &mut self,
1068 request: LanguageModelRequest,
1069 model: Arc<dyn LanguageModel>,
1070 cx: &mut Context<Self>,
1071 ) {
1072 let pending_completion_id = post_inc(&mut self.completion_count);
1073 let task = cx.spawn(async move |thread, cx| {
1074 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1075 let initial_token_usage =
1076 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1077 let stream_completion = async {
1078 let (mut events, usage) = stream_completion_future.await?;
1079 let mut stop_reason = StopReason::EndTurn;
1080 let mut current_token_usage = TokenUsage::default();
1081
1082 if let Some(usage) = usage {
1083 thread
1084 .update(cx, |_thread, cx| {
1085 cx.emit(ThreadEvent::UsageUpdated(usage));
1086 })
1087 .ok();
1088 }
1089
1090 while let Some(event) = events.next().await {
1091 let event = event?;
1092
1093 thread.update(cx, |thread, cx| {
1094 match event {
1095 LanguageModelCompletionEvent::StartMessage { .. } => {
1096 thread.insert_message(
1097 Role::Assistant,
1098 vec![MessageSegment::Text(String::new())],
1099 cx,
1100 );
1101 }
1102 LanguageModelCompletionEvent::Stop(reason) => {
1103 stop_reason = reason;
1104 }
1105 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1106 thread.update_token_usage_at_last_message(token_usage);
1107 thread.cumulative_token_usage = thread.cumulative_token_usage
1108 + token_usage
1109 - current_token_usage;
1110 current_token_usage = token_usage;
1111 }
1112 LanguageModelCompletionEvent::Text(chunk) => {
1113 if let Some(last_message) = thread.messages.last_mut() {
1114 if last_message.role == Role::Assistant {
1115 last_message.push_text(&chunk);
1116 cx.emit(ThreadEvent::StreamedAssistantText(
1117 last_message.id,
1118 chunk,
1119 ));
1120 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1122 // of a new Assistant response.
1123 //
1124 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1125 // will result in duplicating the text of the chunk in the rendered Markdown.
1126 thread.insert_message(
1127 Role::Assistant,
1128 vec![MessageSegment::Text(chunk.to_string())],
1129 cx,
1130 );
1131 };
1132 }
1133 }
1134 LanguageModelCompletionEvent::Thinking(chunk) => {
1135 if let Some(last_message) = thread.messages.last_mut() {
1136 if last_message.role == Role::Assistant {
1137 last_message.push_thinking(&chunk);
1138 cx.emit(ThreadEvent::StreamedAssistantThinking(
1139 last_message.id,
1140 chunk,
1141 ));
1142 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1148 thread.insert_message(
1149 Role::Assistant,
1150 vec![MessageSegment::Thinking(chunk.to_string())],
1151 cx,
1152 );
1153 };
1154 }
1155 }
1156 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1157 let last_assistant_message_id = thread
1158 .messages
1159 .iter_mut()
1160 .rfind(|message| message.role == Role::Assistant)
1161 .map(|message| message.id)
1162 .unwrap_or_else(|| {
1163 thread.insert_message(Role::Assistant, vec![], cx)
1164 });
1165
1166 thread.tool_use.request_tool_use(
1167 last_assistant_message_id,
1168 tool_use,
1169 cx,
1170 );
1171 }
1172 }
1173
1174 thread.touch_updated_at();
1175 cx.emit(ThreadEvent::StreamedCompletion);
1176 cx.notify();
1177
1178 thread.auto_capture_telemetry(cx);
1179 })?;
1180
1181 smol::future::yield_now().await;
1182 }
1183
1184 thread.update(cx, |thread, cx| {
1185 thread
1186 .pending_completions
1187 .retain(|completion| completion.id != pending_completion_id);
1188
1189 if thread.summary.is_none() && thread.messages.len() >= 2 {
1190 thread.summarize(cx);
1191 }
1192 })?;
1193
1194 anyhow::Ok(stop_reason)
1195 };
1196
1197 let result = stream_completion.await;
1198
1199 thread
1200 .update(cx, |thread, cx| {
1201 thread.finalize_pending_checkpoint(cx);
1202 match result.as_ref() {
1203 Ok(stop_reason) => match stop_reason {
1204 StopReason::ToolUse => {
1205 let tool_uses = thread.use_pending_tools(cx);
1206 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1207 }
1208 StopReason::EndTurn => {}
1209 StopReason::MaxTokens => {}
1210 },
1211 Err(error) => {
1212 if error.is::<PaymentRequiredError>() {
1213 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1214 } else if error.is::<MaxMonthlySpendReachedError>() {
1215 cx.emit(ThreadEvent::ShowError(
1216 ThreadError::MaxMonthlySpendReached,
1217 ));
1218 } else if let Some(error) =
1219 error.downcast_ref::<ModelRequestLimitReachedError>()
1220 {
1221 cx.emit(ThreadEvent::ShowError(
1222 ThreadError::ModelRequestLimitReached { plan: error.plan },
1223 ));
1224 } else if let Some(known_error) =
1225 error.downcast_ref::<LanguageModelKnownError>()
1226 {
1227 match known_error {
1228 LanguageModelKnownError::ContextWindowLimitExceeded {
1229 tokens,
1230 } => {
1231 thread.exceeded_window_error = Some(ExceededWindowError {
1232 model_id: model.id(),
1233 token_count: *tokens,
1234 });
1235 cx.notify();
1236 }
1237 }
1238 } else {
1239 let error_message = error
1240 .chain()
1241 .map(|err| err.to_string())
1242 .collect::<Vec<_>>()
1243 .join("\n");
1244 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1245 header: "Error interacting with language model".into(),
1246 message: SharedString::from(error_message.clone()),
1247 }));
1248 }
1249
1250 thread.cancel_last_completion(cx);
1251 }
1252 }
1253 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1254
1255 thread.auto_capture_telemetry(cx);
1256
1257 if let Ok(initial_usage) = initial_token_usage {
1258 let usage = thread.cumulative_token_usage - initial_usage;
1259
1260 telemetry::event!(
1261 "Assistant Thread Completion",
1262 thread_id = thread.id().to_string(),
1263 model = model.telemetry_id(),
1264 model_provider = model.provider_id().to_string(),
1265 input_tokens = usage.input_tokens,
1266 output_tokens = usage.output_tokens,
1267 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1268 cache_read_input_tokens = usage.cache_read_input_tokens,
1269 );
1270 }
1271 })
1272 .ok();
1273 });
1274
1275 self.pending_completions.push(PendingCompletion {
1276 id: pending_completion_id,
1277 _task: task,
1278 });
1279 }
1280
1281 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1282 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1283 return;
1284 };
1285
1286 if !model.provider.is_authenticated(cx) {
1287 return;
1288 }
1289
1290 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1291 request.messages.push(LanguageModelRequestMessage {
1292 role: Role::User,
1293 content: vec![
1294 "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1295 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1296 If the conversation is about a specific subject, include it in the title. \
1297 Be descriptive. DO NOT speak in the first person."
1298 .into(),
1299 ],
1300 cache: false,
1301 });
1302
1303 self.pending_summary = cx.spawn(async move |this, cx| {
1304 async move {
1305 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1306 let (mut messages, usage) = stream.await?;
1307
1308 if let Some(usage) = usage {
1309 this.update(cx, |_thread, cx| {
1310 cx.emit(ThreadEvent::UsageUpdated(usage));
1311 })
1312 .ok();
1313 }
1314
1315 let mut new_summary = String::new();
1316 while let Some(message) = messages.stream.next().await {
1317 let text = message?;
1318 let mut lines = text.lines();
1319 new_summary.extend(lines.next());
1320
1321 // Stop if the LLM generated multiple lines.
1322 if lines.next().is_some() {
1323 break;
1324 }
1325 }
1326
1327 this.update(cx, |this, cx| {
1328 if !new_summary.is_empty() {
1329 this.summary = Some(new_summary.into());
1330 }
1331
1332 cx.emit(ThreadEvent::SummaryGenerated);
1333 })?;
1334
1335 anyhow::Ok(())
1336 }
1337 .log_err()
1338 .await
1339 });
1340 }
1341
1342 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1343 let last_message_id = self.messages.last().map(|message| message.id)?;
1344
1345 match &self.detailed_summary_state {
1346 DetailedSummaryState::Generating { message_id, .. }
1347 | DetailedSummaryState::Generated { message_id, .. }
1348 if *message_id == last_message_id =>
1349 {
1350 // Already up-to-date
1351 return None;
1352 }
1353 _ => {}
1354 }
1355
1356 let ConfiguredModel { model, provider } =
1357 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1358
1359 if !provider.is_authenticated(cx) {
1360 return None;
1361 }
1362
1363 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1364
1365 request.messages.push(LanguageModelRequestMessage {
1366 role: Role::User,
1367 content: vec![
1368 "Generate a detailed summary of this conversation. Include:\n\
1369 1. A brief overview of what was discussed\n\
1370 2. Key facts or information discovered\n\
1371 3. Outcomes or conclusions reached\n\
1372 4. Any action items or next steps if any\n\
1373 Format it in Markdown with headings and bullet points."
1374 .into(),
1375 ],
1376 cache: false,
1377 });
1378
1379 let task = cx.spawn(async move |thread, cx| {
1380 let stream = model.stream_completion_text(request, &cx);
1381 let Some(mut messages) = stream.await.log_err() else {
1382 thread
1383 .update(cx, |this, _cx| {
1384 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1385 })
1386 .log_err();
1387
1388 return;
1389 };
1390
1391 let mut new_detailed_summary = String::new();
1392
1393 while let Some(chunk) = messages.stream.next().await {
1394 if let Some(chunk) = chunk.log_err() {
1395 new_detailed_summary.push_str(&chunk);
1396 }
1397 }
1398
1399 thread
1400 .update(cx, |this, _cx| {
1401 this.detailed_summary_state = DetailedSummaryState::Generated {
1402 text: new_detailed_summary.into(),
1403 message_id: last_message_id,
1404 };
1405 })
1406 .log_err();
1407 });
1408
1409 self.detailed_summary_state = DetailedSummaryState::Generating {
1410 message_id: last_message_id,
1411 };
1412
1413 Some(task)
1414 }
1415
1416 pub fn is_generating_detailed_summary(&self) -> bool {
1417 matches!(
1418 self.detailed_summary_state,
1419 DetailedSummaryState::Generating { .. }
1420 )
1421 }
1422
1423 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1424 self.auto_capture_telemetry(cx);
1425 let request = self.to_completion_request(RequestKind::Chat, cx);
1426 let messages = Arc::new(request.messages);
1427 let pending_tool_uses = self
1428 .tool_use
1429 .pending_tool_uses()
1430 .into_iter()
1431 .filter(|tool_use| tool_use.status.is_idle())
1432 .cloned()
1433 .collect::<Vec<_>>();
1434
1435 for tool_use in pending_tool_uses.iter() {
1436 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1437 if tool.needs_confirmation(&tool_use.input, cx)
1438 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1439 {
1440 self.tool_use.confirm_tool_use(
1441 tool_use.id.clone(),
1442 tool_use.ui_text.clone(),
1443 tool_use.input.clone(),
1444 messages.clone(),
1445 tool,
1446 );
1447 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1448 } else {
1449 self.run_tool(
1450 tool_use.id.clone(),
1451 tool_use.ui_text.clone(),
1452 tool_use.input.clone(),
1453 &messages,
1454 tool,
1455 cx,
1456 );
1457 }
1458 }
1459 }
1460
1461 pending_tool_uses
1462 }
1463
1464 pub fn run_tool(
1465 &mut self,
1466 tool_use_id: LanguageModelToolUseId,
1467 ui_text: impl Into<SharedString>,
1468 input: serde_json::Value,
1469 messages: &[LanguageModelRequestMessage],
1470 tool: Arc<dyn Tool>,
1471 cx: &mut Context<Thread>,
1472 ) {
1473 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1474 self.tool_use
1475 .run_pending_tool(tool_use_id, ui_text.into(), task);
1476 }
1477
1478 fn spawn_tool_use(
1479 &mut self,
1480 tool_use_id: LanguageModelToolUseId,
1481 messages: &[LanguageModelRequestMessage],
1482 input: serde_json::Value,
1483 tool: Arc<dyn Tool>,
1484 cx: &mut Context<Thread>,
1485 ) -> Task<()> {
1486 let tool_name: Arc<str> = tool.name().into();
1487
1488 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1489 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1490 } else {
1491 tool.run(
1492 input,
1493 messages,
1494 self.project.clone(),
1495 self.action_log.clone(),
1496 cx,
1497 )
1498 };
1499
1500 // Store the card separately if it exists
1501 if let Some(card) = tool_result.card.clone() {
1502 self.tool_use
1503 .insert_tool_result_card(tool_use_id.clone(), card);
1504 }
1505
1506 cx.spawn({
1507 async move |thread: WeakEntity<Thread>, cx| {
1508 let output = tool_result.output.await;
1509
1510 thread
1511 .update(cx, |thread, cx| {
1512 let pending_tool_use = thread.tool_use.insert_tool_output(
1513 tool_use_id.clone(),
1514 tool_name,
1515 output,
1516 cx,
1517 );
1518 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1519 })
1520 .ok();
1521 }
1522 })
1523 }
1524
1525 fn tool_finished(
1526 &mut self,
1527 tool_use_id: LanguageModelToolUseId,
1528 pending_tool_use: Option<PendingToolUse>,
1529 canceled: bool,
1530 cx: &mut Context<Self>,
1531 ) {
1532 if self.all_tools_finished() {
1533 let model_registry = LanguageModelRegistry::read_global(cx);
1534 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1535 self.attach_tool_results(cx);
1536 if !canceled {
1537 self.send_to_model(model, RequestKind::Chat, cx);
1538 }
1539 }
1540 }
1541
1542 cx.emit(ThreadEvent::ToolFinished {
1543 tool_use_id,
1544 pending_tool_use,
1545 });
1546 }
1547
1548 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1549 // Insert a user message to contain the tool results.
1550 self.insert_user_message(
1551 // TODO: Sending up a user message without any content results in the model sending back
1552 // responses that also don't have any content. We currently don't handle this case well,
1553 // so for now we provide some text to keep the model on track.
1554 "Here are the tool results.",
1555 Vec::new(),
1556 None,
1557 cx,
1558 );
1559 }
1560
1561 /// Cancels the last pending completion, if there are any pending.
1562 ///
1563 /// Returns whether a completion was canceled.
1564 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1565 let canceled = if self.pending_completions.pop().is_some() {
1566 true
1567 } else {
1568 let mut canceled = false;
1569 for pending_tool_use in self.tool_use.cancel_pending() {
1570 canceled = true;
1571 self.tool_finished(
1572 pending_tool_use.id.clone(),
1573 Some(pending_tool_use),
1574 true,
1575 cx,
1576 );
1577 }
1578 canceled
1579 };
1580 self.finalize_pending_checkpoint(cx);
1581 canceled
1582 }
1583
1584 pub fn feedback(&self) -> Option<ThreadFeedback> {
1585 self.feedback
1586 }
1587
1588 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1589 self.message_feedback.get(&message_id).copied()
1590 }
1591
1592 pub fn report_message_feedback(
1593 &mut self,
1594 message_id: MessageId,
1595 feedback: ThreadFeedback,
1596 cx: &mut Context<Self>,
1597 ) -> Task<Result<()>> {
1598 if self.message_feedback.get(&message_id) == Some(&feedback) {
1599 return Task::ready(Ok(()));
1600 }
1601
1602 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1603 let serialized_thread = self.serialize(cx);
1604 let thread_id = self.id().clone();
1605 let client = self.project.read(cx).client();
1606
1607 let enabled_tool_names: Vec<String> = self
1608 .tools()
1609 .read(cx)
1610 .enabled_tools(cx)
1611 .iter()
1612 .map(|tool| tool.name().to_string())
1613 .collect();
1614
1615 self.message_feedback.insert(message_id, feedback);
1616
1617 cx.notify();
1618
1619 let message_content = self
1620 .message(message_id)
1621 .map(|msg| msg.to_string())
1622 .unwrap_or_default();
1623
1624 cx.background_spawn(async move {
1625 let final_project_snapshot = final_project_snapshot.await;
1626 let serialized_thread = serialized_thread.await?;
1627 let thread_data =
1628 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1629
1630 let rating = match feedback {
1631 ThreadFeedback::Positive => "positive",
1632 ThreadFeedback::Negative => "negative",
1633 };
1634 telemetry::event!(
1635 "Assistant Thread Rated",
1636 rating,
1637 thread_id,
1638 enabled_tool_names,
1639 message_id = message_id.0,
1640 message_content,
1641 thread_data,
1642 final_project_snapshot
1643 );
1644 client.telemetry().flush_events();
1645
1646 Ok(())
1647 })
1648 }
1649
1650 pub fn report_feedback(
1651 &mut self,
1652 feedback: ThreadFeedback,
1653 cx: &mut Context<Self>,
1654 ) -> Task<Result<()>> {
1655 let last_assistant_message_id = self
1656 .messages
1657 .iter()
1658 .rev()
1659 .find(|msg| msg.role == Role::Assistant)
1660 .map(|msg| msg.id);
1661
1662 if let Some(message_id) = last_assistant_message_id {
1663 self.report_message_feedback(message_id, feedback, cx)
1664 } else {
1665 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1666 let serialized_thread = self.serialize(cx);
1667 let thread_id = self.id().clone();
1668 let client = self.project.read(cx).client();
1669 self.feedback = Some(feedback);
1670 cx.notify();
1671
1672 cx.background_spawn(async move {
1673 let final_project_snapshot = final_project_snapshot.await;
1674 let serialized_thread = serialized_thread.await?;
1675 let thread_data = serde_json::to_value(serialized_thread)
1676 .unwrap_or_else(|_| serde_json::Value::Null);
1677
1678 let rating = match feedback {
1679 ThreadFeedback::Positive => "positive",
1680 ThreadFeedback::Negative => "negative",
1681 };
1682 telemetry::event!(
1683 "Assistant Thread Rated",
1684 rating,
1685 thread_id,
1686 thread_data,
1687 final_project_snapshot
1688 );
1689 client.telemetry().flush_events();
1690
1691 Ok(())
1692 })
1693 }
1694 }
1695
1696 /// Create a snapshot of the current project state including git information and unsaved buffers.
1697 fn project_snapshot(
1698 project: Entity<Project>,
1699 cx: &mut Context<Self>,
1700 ) -> Task<Arc<ProjectSnapshot>> {
1701 let git_store = project.read(cx).git_store().clone();
1702 let worktree_snapshots: Vec<_> = project
1703 .read(cx)
1704 .visible_worktrees(cx)
1705 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1706 .collect();
1707
1708 cx.spawn(async move |_, cx| {
1709 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1710
1711 let mut unsaved_buffers = Vec::new();
1712 cx.update(|app_cx| {
1713 let buffer_store = project.read(app_cx).buffer_store();
1714 for buffer_handle in buffer_store.read(app_cx).buffers() {
1715 let buffer = buffer_handle.read(app_cx);
1716 if buffer.is_dirty() {
1717 if let Some(file) = buffer.file() {
1718 let path = file.path().to_string_lossy().to_string();
1719 unsaved_buffers.push(path);
1720 }
1721 }
1722 }
1723 })
1724 .ok();
1725
1726 Arc::new(ProjectSnapshot {
1727 worktree_snapshots,
1728 unsaved_buffer_paths: unsaved_buffers,
1729 timestamp: Utc::now(),
1730 })
1731 })
1732 }
1733
1734 fn worktree_snapshot(
1735 worktree: Entity<project::Worktree>,
1736 git_store: Entity<GitStore>,
1737 cx: &App,
1738 ) -> Task<WorktreeSnapshot> {
1739 cx.spawn(async move |cx| {
1740 // Get worktree path and snapshot
1741 let worktree_info = cx.update(|app_cx| {
1742 let worktree = worktree.read(app_cx);
1743 let path = worktree.abs_path().to_string_lossy().to_string();
1744 let snapshot = worktree.snapshot();
1745 (path, snapshot)
1746 });
1747
1748 let Ok((worktree_path, _snapshot)) = worktree_info else {
1749 return WorktreeSnapshot {
1750 worktree_path: String::new(),
1751 git_state: None,
1752 };
1753 };
1754
1755 let git_state = git_store
1756 .update(cx, |git_store, cx| {
1757 git_store
1758 .repositories()
1759 .values()
1760 .find(|repo| {
1761 repo.read(cx)
1762 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1763 .is_some()
1764 })
1765 .cloned()
1766 })
1767 .ok()
1768 .flatten()
1769 .map(|repo| {
1770 repo.update(cx, |repo, _| {
1771 let current_branch =
1772 repo.branch.as_ref().map(|branch| branch.name.to_string());
1773 repo.send_job(None, |state, _| async move {
1774 let RepositoryState::Local { backend, .. } = state else {
1775 return GitState {
1776 remote_url: None,
1777 head_sha: None,
1778 current_branch,
1779 diff: None,
1780 };
1781 };
1782
1783 let remote_url = backend.remote_url("origin");
1784 let head_sha = backend.head_sha();
1785 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1786
1787 GitState {
1788 remote_url,
1789 head_sha,
1790 current_branch,
1791 diff,
1792 }
1793 })
1794 })
1795 });
1796
1797 let git_state = match git_state {
1798 Some(git_state) => match git_state.ok() {
1799 Some(git_state) => git_state.await.ok(),
1800 None => None,
1801 },
1802 None => None,
1803 };
1804
1805 WorktreeSnapshot {
1806 worktree_path,
1807 git_state,
1808 }
1809 })
1810 }
1811
1812 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1813 let mut markdown = Vec::new();
1814
1815 if let Some(summary) = self.summary() {
1816 writeln!(markdown, "# {summary}\n")?;
1817 };
1818
1819 for message in self.messages() {
1820 writeln!(
1821 markdown,
1822 "## {role}\n",
1823 role = match message.role {
1824 Role::User => "User",
1825 Role::Assistant => "Assistant",
1826 Role::System => "System",
1827 }
1828 )?;
1829
1830 if !message.context.is_empty() {
1831 writeln!(markdown, "{}", message.context)?;
1832 }
1833
1834 for segment in &message.segments {
1835 match segment {
1836 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1837 MessageSegment::Thinking(text) => {
1838 writeln!(markdown, "<think>{}</think>\n", text)?
1839 }
1840 }
1841 }
1842
1843 for tool_use in self.tool_uses_for_message(message.id, cx) {
1844 writeln!(
1845 markdown,
1846 "**Use Tool: {} ({})**",
1847 tool_use.name, tool_use.id
1848 )?;
1849 writeln!(markdown, "```json")?;
1850 writeln!(
1851 markdown,
1852 "{}",
1853 serde_json::to_string_pretty(&tool_use.input)?
1854 )?;
1855 writeln!(markdown, "```")?;
1856 }
1857
1858 for tool_result in self.tool_results_for_message(message.id) {
1859 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1860 if tool_result.is_error {
1861 write!(markdown, " (Error)")?;
1862 }
1863
1864 writeln!(markdown, "**\n")?;
1865 writeln!(markdown, "{}", tool_result.content)?;
1866 }
1867 }
1868
1869 Ok(String::from_utf8_lossy(&markdown).to_string())
1870 }
1871
1872 pub fn keep_edits_in_range(
1873 &mut self,
1874 buffer: Entity<language::Buffer>,
1875 buffer_range: Range<language::Anchor>,
1876 cx: &mut Context<Self>,
1877 ) {
1878 self.action_log.update(cx, |action_log, cx| {
1879 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1880 });
1881 }
1882
1883 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1884 self.action_log
1885 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1886 }
1887
1888 pub fn reject_edits_in_ranges(
1889 &mut self,
1890 buffer: Entity<language::Buffer>,
1891 buffer_ranges: Vec<Range<language::Anchor>>,
1892 cx: &mut Context<Self>,
1893 ) -> Task<Result<()>> {
1894 self.action_log.update(cx, |action_log, cx| {
1895 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1896 })
1897 }
1898
1899 pub fn action_log(&self) -> &Entity<ActionLog> {
1900 &self.action_log
1901 }
1902
1903 pub fn project(&self) -> &Entity<Project> {
1904 &self.project
1905 }
1906
1907 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1908 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1909 return;
1910 }
1911
1912 let now = Instant::now();
1913 if let Some(last) = self.last_auto_capture_at {
1914 if now.duration_since(last).as_secs() < 10 {
1915 return;
1916 }
1917 }
1918
1919 self.last_auto_capture_at = Some(now);
1920
1921 let thread_id = self.id().clone();
1922 let github_login = self
1923 .project
1924 .read(cx)
1925 .user_store()
1926 .read(cx)
1927 .current_user()
1928 .map(|user| user.github_login.clone());
1929 let client = self.project.read(cx).client().clone();
1930 let serialize_task = self.serialize(cx);
1931
1932 cx.background_executor()
1933 .spawn(async move {
1934 if let Ok(serialized_thread) = serialize_task.await {
1935 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1936 telemetry::event!(
1937 "Agent Thread Auto-Captured",
1938 thread_id = thread_id.to_string(),
1939 thread_data = thread_data,
1940 auto_capture_reason = "tracked_user",
1941 github_login = github_login
1942 );
1943
1944 client.telemetry().flush_events();
1945 }
1946 }
1947 })
1948 .detach();
1949 }
1950
1951 pub fn cumulative_token_usage(&self) -> TokenUsage {
1952 self.cumulative_token_usage
1953 }
1954
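    /// Returns the token usage recorded as of the message preceding `message_id`
    /// (zero for the first message), measured against the default model's context
    /// window.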
1955 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
1956 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
1957 return TotalTokenUsage::default();
1958 };
1959
1960 let max = model.model.max_token_count();
1961
1962 let index = self
1963 .messages
1964 .iter()
1965 .position(|msg| msg.id == message_id)
1966 .unwrap_or(0);
1967
1968 if index == 0 {
1969 return TotalTokenUsage { total: 0, max };
1970 }
1971
1972 let token_usage = &self
1973 .request_token_usage
1974 .get(index - 1)
1975 .cloned()
1976 .unwrap_or_default();
1977
1978 TotalTokenUsage {
1979 total: token_usage.total_tokens() as usize,
1980 max,
1981 }
1982 }
1983
1984 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1985 let model_registry = LanguageModelRegistry::read_global(cx);
1986 let Some(model) = model_registry.default_model() else {
1987 return TotalTokenUsage::default();
1988 };
1989
1990 let max = model.model.max_token_count();
1991
1992 if let Some(exceeded_error) = &self.exceeded_window_error {
1993 if model.model.id() == exceeded_error.model_id {
1994 return TotalTokenUsage {
1995 total: exceeded_error.token_count,
1996 max,
1997 };
1998 }
1999 }
2000
2001 let total = self
2002 .token_usage_at_last_message()
2003 .unwrap_or_default()
2004 .total_tokens() as usize;
2005
2006 TotalTokenUsage { total, max }
2007 }
2008
2009 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2010 self.request_token_usage
2011 .get(self.messages.len().saturating_sub(1))
2012 .or_else(|| self.request_token_usage.last())
2013 .cloned()
2014 }
2015
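    /// Keeps `request_token_usage` index-aligned with `messages`, then records the
    /// latest usage against the last message.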
2016 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2017 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2018 self.request_token_usage
2019 .resize(self.messages.len(), placeholder);
2020
2021 if let Some(last) = self.request_token_usage.last_mut() {
2022 *last = token_usage;
2023 }
2024 }
2025
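    /// Reports a user-denied tool use by recording an error result for the tool
    /// call and marking it as finished.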
2026 pub fn deny_tool_use(
2027 &mut self,
2028 tool_use_id: LanguageModelToolUseId,
2029 tool_name: Arc<str>,
2030 cx: &mut Context<Self>,
2031 ) {
2032 let err = Err(anyhow::anyhow!(
2033 "Permission to run tool action denied by user"
2034 ));
2035
2036 self.tool_use
2037 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2038 self.tool_finished(tool_use_id.clone(), None, true, cx);
2039 }
2040}
2041
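/// Errors surfaced to the user while running a thread.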
2042#[derive(Debug, Clone, Error)]
2043pub enum ThreadError {
2044 #[error("Payment required")]
2045 PaymentRequired,
2046 #[error("Max monthly spend reached")]
2047 MaxMonthlySpendReached,
2048 #[error("Model request limit reached")]
2049 ModelRequestLimitReached { plan: Plan },
2050 #[error("Message {header}: {message}")]
2051 Message {
2052 header: SharedString,
2053 message: SharedString,
2054 },
2055}
2056
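/// Events emitted by a [`Thread`] as completions stream and tools run.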
2057#[derive(Debug, Clone)]
2058pub enum ThreadEvent {
2059 ShowError(ThreadError),
2060 UsageUpdated(RequestUsage),
2061 StreamedCompletion,
2062 StreamedAssistantText(MessageId, String),
2063 StreamedAssistantThinking(MessageId, String),
2064 Stopped(Result<StopReason, Arc<anyhow::Error>>),
2065 MessageAdded(MessageId),
2066 MessageEdited(MessageId),
2067 MessageDeleted(MessageId),
2068 SummaryGenerated,
2069 SummaryChanged,
2070 UsePendingTools {
2071 tool_uses: Vec<PendingToolUse>,
2072 },
2073 ToolFinished {
2074 #[allow(unused)]
2075 tool_use_id: LanguageModelToolUseId,
2076 /// The pending tool use that corresponds to this tool.
2077 pending_tool_use: Option<PendingToolUse>,
2078 },
2079 CheckpointChanged,
2080 ToolConfirmationNeeded,
2081}
2082
2083impl EventEmitter<ThreadEvent> for Thread {}
2084
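/// A completion request that is currently in flight, kept alive by holding its task.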
2085struct PendingCompletion {
2086 id: usize,
2087 _task: Task<()>,
2088}
2089
2090#[cfg(test)]
2091mod tests {
2092 use super::*;
2093 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2094 use assistant_settings::AssistantSettings;
2095 use context_server::ContextServerSettings;
2096 use editor::EditorSettings;
2097 use gpui::TestAppContext;
2098 use project::{FakeFs, Project};
2099 use prompt_store::PromptBuilder;
2100 use serde_json::json;
2101 use settings::{Settings, SettingsStore};
2102 use std::sync::Arc;
2103 use theme::ThemeSettings;
2104 use util::path;
2105 use workspace::Workspace;
2106
2107 #[gpui::test]
2108 async fn test_message_with_context(cx: &mut TestAppContext) {
2109 init_test_settings(cx);
2110
2111 let project = create_test_project(
2112 cx,
2113 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2114 )
2115 .await;
2116
2117 let (_workspace, _thread_store, thread, context_store) =
2118 setup_test_environment(cx, project.clone()).await;
2119
2120 add_file_to_context(&project, &context_store, "test/code.rs", cx)
2121 .await
2122 .unwrap();
2123
2124 let context =
2125 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2126
2127 // Insert user message with context
2128 let message_id = thread.update(cx, |thread, cx| {
2129 thread.insert_user_message("Please explain this code", vec![context], None, cx)
2130 });
2131
2132 // Check content and context in message object
2133 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2134
        // File paths in the rendered context use platform-specific separators
2136 #[cfg(windows)]
2137 let path_part = r"test\code.rs";
2138 #[cfg(not(windows))]
2139 let path_part = "test/code.rs";
2140
2141 let expected_context = format!(
2142 r#"
2143<context>
2144The following items were attached by the user. You don't need to use other tools to read them.
2145
2146<files>
2147```rs {path_part}
2148fn main() {{
2149 println!("Hello, world!");
2150}}
2151```
2152</files>
2153</context>
2154"#
2155 );
2156
2157 assert_eq!(message.role, Role::User);
2158 assert_eq!(message.segments.len(), 1);
2159 assert_eq!(
2160 message.segments[0],
2161 MessageSegment::Text("Please explain this code".to_string())
2162 );
2163 assert_eq!(message.context, expected_context);
2164
2165 // Check message in request
2166 let request = thread.read_with(cx, |thread, cx| {
2167 thread.to_completion_request(RequestKind::Chat, cx)
2168 });
2169
2170 assert_eq!(request.messages.len(), 2);
2171 let expected_full_message = format!("{}Please explain this code", expected_context);
2172 assert_eq!(request.messages[1].string_contents(), expected_full_message);
2173 }
2174
2175 #[gpui::test]
2176 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2177 init_test_settings(cx);
2178
2179 let project = create_test_project(
2180 cx,
2181 json!({
2182 "file1.rs": "fn function1() {}\n",
2183 "file2.rs": "fn function2() {}\n",
2184 "file3.rs": "fn function3() {}\n",
2185 }),
2186 )
2187 .await;
2188
2189 let (_, _thread_store, thread, context_store) =
2190 setup_test_environment(cx, project.clone()).await;
2191
2192 // Open files individually
2193 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2194 .await
2195 .unwrap();
2196 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2197 .await
2198 .unwrap();
2199 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2200 .await
2201 .unwrap();
2202
2203 // Get the context objects
2204 let contexts = context_store.update(cx, |store, _| store.context().clone());
2205 assert_eq!(contexts.len(), 3);
2206
2207 // First message with context 1
2208 let message1_id = thread.update(cx, |thread, cx| {
2209 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2210 });
2211
2212 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2213 let message2_id = thread.update(cx, |thread, cx| {
2214 thread.insert_user_message(
2215 "Message 2",
2216 vec![contexts[0].clone(), contexts[1].clone()],
2217 None,
2218 cx,
2219 )
2220 });
2221
2222 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2223 let message3_id = thread.update(cx, |thread, cx| {
2224 thread.insert_user_message(
2225 "Message 3",
2226 vec![
2227 contexts[0].clone(),
2228 contexts[1].clone(),
2229 contexts[2].clone(),
2230 ],
2231 None,
2232 cx,
2233 )
2234 });
2235
2236 // Check what contexts are included in each message
2237 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2238 (
2239 thread.message(message1_id).unwrap().clone(),
2240 thread.message(message2_id).unwrap().clone(),
2241 thread.message(message3_id).unwrap().clone(),
2242 )
2243 });
2244
2245 // First message should include context 1
2246 assert!(message1.context.contains("file1.rs"));
2247
2248 // Second message should include only context 2 (not 1)
2249 assert!(!message2.context.contains("file1.rs"));
2250 assert!(message2.context.contains("file2.rs"));
2251
2252 // Third message should include only context 3 (not 1 or 2)
2253 assert!(!message3.context.contains("file1.rs"));
2254 assert!(!message3.context.contains("file2.rs"));
2255 assert!(message3.context.contains("file3.rs"));
2256
2257 // Check entire request to make sure all contexts are properly included
2258 let request = thread.read_with(cx, |thread, cx| {
2259 thread.to_completion_request(RequestKind::Chat, cx)
2260 });
2261
        // The request should contain all three user messages
2263 assert_eq!(request.messages.len(), 4);
2264
2265 // Check that the contexts are properly formatted in each message
2266 assert!(request.messages[1].string_contents().contains("file1.rs"));
2267 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2268 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2269
2270 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2271 assert!(request.messages[2].string_contents().contains("file2.rs"));
2272 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2273
2274 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2275 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2276 assert!(request.messages[3].string_contents().contains("file3.rs"));
2277 }
2278
2279 #[gpui::test]
2280 async fn test_message_without_files(cx: &mut TestAppContext) {
2281 init_test_settings(cx);
2282
2283 let project = create_test_project(
2284 cx,
2285 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2286 )
2287 .await;
2288
2289 let (_, _thread_store, thread, _context_store) =
2290 setup_test_environment(cx, project.clone()).await;
2291
2292 // Insert user message without any context (empty context vector)
2293 let message_id = thread.update(cx, |thread, cx| {
2294 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2295 });
2296
2297 // Check content and context in message object
2298 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2299
2300 // Context should be empty when no files are included
2301 assert_eq!(message.role, Role::User);
2302 assert_eq!(message.segments.len(), 1);
2303 assert_eq!(
2304 message.segments[0],
2305 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2306 );
2307 assert_eq!(message.context, "");
2308
2309 // Check message in request
2310 let request = thread.read_with(cx, |thread, cx| {
2311 thread.to_completion_request(RequestKind::Chat, cx)
2312 });
2313
2314 assert_eq!(request.messages.len(), 2);
2315 assert_eq!(
2316 request.messages[1].string_contents(),
2317 "What is the best way to learn Rust?"
2318 );
2319
2320 // Add second message, also without context
2321 let message2_id = thread.update(cx, |thread, cx| {
2322 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2323 });
2324
2325 let message2 =
2326 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2327 assert_eq!(message2.context, "");
2328
2329 // Check that both messages appear in the request
2330 let request = thread.read_with(cx, |thread, cx| {
2331 thread.to_completion_request(RequestKind::Chat, cx)
2332 });
2333
2334 assert_eq!(request.messages.len(), 3);
2335 assert_eq!(
2336 request.messages[1].string_contents(),
2337 "What is the best way to learn Rust?"
2338 );
2339 assert_eq!(
2340 request.messages[2].string_contents(),
2341 "Are there any good books?"
2342 );
2343 }
2344
2345 #[gpui::test]
2346 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2347 init_test_settings(cx);
2348
2349 let project = create_test_project(
2350 cx,
2351 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2352 )
2353 .await;
2354
2355 let (_workspace, _thread_store, thread, context_store) =
2356 setup_test_environment(cx, project.clone()).await;
2357
2358 // Open buffer and add it to context
2359 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2360 .await
2361 .unwrap();
2362
2363 let context =
2364 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2365
2366 // Insert user message with the buffer as context
2367 thread.update(cx, |thread, cx| {
2368 thread.insert_user_message("Explain this code", vec![context], None, cx)
2369 });
2370
2371 // Create a request and check that it doesn't have a stale buffer warning yet
2372 let initial_request = thread.read_with(cx, |thread, cx| {
2373 thread.to_completion_request(RequestKind::Chat, cx)
2374 });
2375
2376 // Make sure we don't have a stale file warning yet
2377 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2378 msg.string_contents()
2379 .contains("These files changed since last read:")
2380 });
2381 assert!(
2382 !has_stale_warning,
2383 "Should not have stale buffer warning before buffer is modified"
2384 );
2385
2386 // Modify the buffer
2387 buffer.update(cx, |buffer, cx| {
            // Insert an extra line near the start; any edit marks the buffer as stale
2389 buffer.edit(
2390 [(1..1, "\n println!(\"Added a new line\");\n")],
2391 None,
2392 cx,
2393 );
2394 });
2395
2396 // Insert another user message without context
2397 thread.update(cx, |thread, cx| {
2398 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2399 });
2400
2401 // Create a new request and check for the stale buffer warning
2402 let new_request = thread.read_with(cx, |thread, cx| {
2403 thread.to_completion_request(RequestKind::Chat, cx)
2404 });
2405
2406 // We should have a stale file warning as the last message
2407 let last_message = new_request
2408 .messages
2409 .last()
2410 .expect("Request should have messages");
2411
2412 // The last message should be the stale buffer notification
2413 assert_eq!(last_message.role, Role::User);
2414
2415 // Check the exact content of the message
2416 let expected_content = "These files changed since last read:\n- code.rs\n";
2417 assert_eq!(
2418 last_message.string_contents(),
2419 expected_content,
2420 "Last message should be exactly the stale buffer notification"
2421 );
2422 }
2423
2424 fn init_test_settings(cx: &mut TestAppContext) {
2425 cx.update(|cx| {
2426 let settings_store = SettingsStore::test(cx);
2427 cx.set_global(settings_store);
2428 language::init(cx);
2429 Project::init_settings(cx);
2430 AssistantSettings::register(cx);
2431 thread_store::init(cx);
2432 workspace::init_settings(cx);
2433 ThemeSettings::register(cx);
2434 ContextServerSettings::register(cx);
2435 EditorSettings::register(cx);
2436 });
2437 }
2438
2439 // Helper to create a test project with test files
2440 async fn create_test_project(
2441 cx: &mut TestAppContext,
2442 files: serde_json::Value,
2443 ) -> Entity<Project> {
2444 let fs = FakeFs::new(cx.executor());
2445 fs.insert_tree(path!("/test"), files).await;
2446 Project::test(fs, [path!("/test").as_ref()], cx).await
2447 }
2448
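    // Helper to set up a workspace, thread store, thread, and context store for tests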
2449 async fn setup_test_environment(
2450 cx: &mut TestAppContext,
2451 project: Entity<Project>,
2452 ) -> (
2453 Entity<Workspace>,
2454 Entity<ThreadStore>,
2455 Entity<Thread>,
2456 Entity<ContextStore>,
2457 ) {
2458 let (workspace, cx) =
2459 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2460
2461 let thread_store = cx
2462 .update(|_, cx| {
2463 ThreadStore::load(
2464 project.clone(),
2465 cx.new(|_| ToolWorkingSet::default()),
2466 Arc::new(PromptBuilder::new(None).unwrap()),
2467 cx,
2468 )
2469 })
2470 .await;
2471
2472 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2473 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2474
2475 (workspace, thread_store, thread, context_store)
2476 }
2477
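    // Helper to open a file from the project and add its buffer to the context store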
2478 async fn add_file_to_context(
2479 project: &Entity<Project>,
2480 context_store: &Entity<ContextStore>,
2481 path: &str,
2482 cx: &mut TestAppContext,
2483 ) -> Result<Entity<language::Buffer>> {
2484 let buffer_path = project
2485 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2486 .unwrap();
2487
2488 let buffer = project
2489 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2490 .await
2491 .unwrap();
2492
2493 context_store
2494 .update(cx, |store, cx| {
2495 store.add_file_from_buffer(buffer.clone(), cx)
2496 })
2497 .await?;
2498
2499 Ok(buffer)
2500 }
2501}