1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
18};
19use language_model::{
20 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
21 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
22 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
23 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
24 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
25 TokenUsage,
26};
27use project::Project;
28use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
29use prompt_store::PromptBuilder;
30use proto::Plan;
31use schemars::JsonSchema;
32use serde::{Deserialize, Serialize};
33use settings::Settings;
34use thiserror::Error;
35use util::{ResultExt as _, TryFutureExt as _, post_inc};
36use uuid::Uuid;
37
38use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
39use crate::thread_store::{
40 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
41 SerializedToolUse, SharedProjectContext,
42};
43use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
44
/// Globally unique identifier for a [`Thread`].
///
/// Backed by a shared string: newly created threads get a random v4 UUID,
/// while persisted threads are reconstructed via `From<&str>`.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    /// Creates a fresh ID from a random v4 UUID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    // Used when rehydrating an ID from its serialized string form.
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}

/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a fresh prompt ID from a random v4 UUID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

/// Identifier for a [`Message`], allocated sequentially within a thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Sequential, thread-local identifier.
    pub id: MessageId,
    /// Author of the message: user, assistant, or system.
    pub role: Role,
    /// Ordered content pieces (text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context that was loaded and attached alongside this message.
    pub loaded_context: LoadedContext,
}
103
104impl Message {
105 /// Returns whether the message contains any meaningful text that should be displayed
106 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
107 pub fn should_display_content(&self) -> bool {
108 self.segments.iter().all(|segment| segment.should_display())
109 }
110
111 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
112 if let Some(MessageSegment::Thinking {
113 text: segment,
114 signature: current_signature,
115 }) = self.segments.last_mut()
116 {
117 if let Some(signature) = signature {
118 *current_signature = Some(signature);
119 }
120 segment.push_str(text);
121 } else {
122 self.segments.push(MessageSegment::Thinking {
123 text: text.to_string(),
124 signature,
125 });
126 }
127 }
128
129 pub fn push_text(&mut self, text: &str) {
130 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
131 segment.push_str(text);
132 } else {
133 self.segments.push(MessageSegment::Text(text.to_string()));
134 }
135 }
136
137 pub fn to_string(&self) -> String {
138 let mut result = String::new();
139
140 if !self.loaded_context.text.is_empty() {
141 result.push_str(&self.loaded_context.text);
142 }
143
144 for segment in &self.segments {
145 match segment {
146 MessageSegment::Text(text) => result.push_str(text),
147 MessageSegment::Thinking { text, .. } => {
148 result.push_str("<think>\n");
149 result.push_str(text);
150 result.push_str("\n</think>");
151 }
152 MessageSegment::RedactedThinking(_) => {}
153 }
154 }
155
156 result
157 }
158}
159
/// One contiguous piece of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary message text.
    Text(String),
    /// Model "thinking" output, with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque, provider-redacted thinking payload; never displayed.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment carries displayable content.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking is never displayable.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
179
/// Point-in-time capture of the project's worktrees, persisted alongside the
/// serialized thread (see `initial_project_snapshot`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One snapshot per worktree in the project.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree root.
    pub worktree_path: String,
    /// Git metadata, or `None` when no repository state was captured.
    pub git_state: Option<GitState>,
}

/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// Remote URL, if one was recorded.
    pub remote_url: Option<String>,
    /// SHA of the HEAD commit, if any.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch, if any.
    pub current_branch: Option<String>,
    /// Diff text, if it could be computed.
    pub diff: Option<String>,
}

/// Associates a git checkpoint with the message that triggered it, so the
/// project can later be rolled back to the state before that message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up/down) on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
212
213pub enum LastRestoreCheckpoint {
214 Pending {
215 message_id: MessageId,
216 },
217 Error {
218 message_id: MessageId,
219 error: String,
220 },
221}
222
223impl LastRestoreCheckpoint {
224 pub fn message_id(&self) -> MessageId {
225 match self {
226 LastRestoreCheckpoint::Pending { message_id } => *message_id,
227 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
228 }
229 }
230}
231
/// Progress of detailed-summary generation for a thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// Generation is in flight; `message_id` records the message the
    /// generation was keyed to (presumably the latest one — confirm at call sites).
    Generating {
        message_id: MessageId,
    },
    /// Generation finished, producing `text`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
244
/// Running token count for a thread measured against the model's
/// context-window size.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens the context window allows.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, nearing the limit, or over it.
    ///
    /// The warning threshold defaults to 80% of `max`; in debug builds it can
    /// be overridden via the `ZED_THREAD_WARNING_THRESHOLD` environment
    /// variable. A missing or unparsable override now falls back to the
    /// default instead of panicking on `parse().unwrap()`.
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When `max` is 0, `total >= max` always holds, so usage reports as
        // exceeded rather than dividing by zero below.
        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more counted against the
    /// same maximum.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Severity bucket produced by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
285
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Stable unique identifier for this thread.
    id: ThreadId,
    /// Last time the thread was modified.
    updated_at: DateTime<Utc>,
    /// Short thread title; `None` until first generated.
    summary: Option<SharedString>,
    /// In-flight summary-generation task, if any.
    pending_summary: Task<Option<()>>,
    /// Progress of on-demand detailed-summary generation.
    detailed_summary_state: DetailedSummaryState,
    /// All messages, in ascending `MessageId` order.
    messages: Vec<Message>,
    /// Next `MessageId` to allocate.
    next_message_id: MessageId,
    /// ID of the most recent user-initiated prompt.
    last_prompt_id: PromptId,
    /// Shared context used to build the system prompt.
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the message that created them.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    /// Counter used to assign IDs to pending completions.
    completion_count: usize,
    /// Completions currently streaming from the model.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    /// Tools available to the model for this thread.
    tools: Entity<ToolWorkingSet>,
    /// Tracks tool uses and their results across messages.
    tool_use: ToolUseState,
    /// Log of agent actions (e.g. tracked and stale buffers).
    action_log: Entity<ActionLog>,
    /// Outcome of the most recent checkpoint restore, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint staged at the start of a turn, finalized when it ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Snapshot of the project captured when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    /// Total token usage across all requests in this thread.
    cumulative_token_usage: TokenUsage,
    /// Set when the last message exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    /// Thread-level user feedback, if given.
    feedback: Option<ThreadFeedback>,
    /// Per-message user feedback.
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    /// Last time thread telemetry was auto-captured.
    last_auto_capture_at: Option<Instant>,
    /// Optional hook invoked with each request and its raw response events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Remaining agentic turns; decremented by `send_to_model`.
    remaining_turns: u32,
}

/// Error details recorded when a request overflowed the context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
327
328impl Thread {
    /// Creates an empty thread with a fresh [`ThreadId`], kicking off an
    /// initial project snapshot in the background.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the snapshot asynchronously; consumers await the
            // shared task (e.g. `serialize`).
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
372
    /// Reconstructs a thread from its serialized form.
    ///
    /// Persisted state (messages, summary, token usage, tool uses) is
    /// restored; ephemeral state (checkpoints, feedback, pending
    /// completions) starts fresh.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue ID allocation after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text is persisted; the
                    // structured contexts and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            // The snapshot was taken when the thread was first created.
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
447
    /// Registers a hook that observes every model request together with the
    /// raw stream of response events.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
            + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// This thread's stable unique ID.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt, so subsequent requests carry a fresh [`PromptId`].
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if none has been generated yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared context used when building the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Placeholder title used until a summary has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The summary, or [`Self::DEFAULT_SUMMARY`] when none exists yet.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
489
490 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
491 let Some(current_summary) = &self.summary else {
492 // Don't allow setting summary until generated
493 return;
494 };
495
496 let mut new_summary = new_summary.into();
497
498 if new_summary.is_empty() {
499 new_summary = Self::DEFAULT_SUMMARY;
500 }
501
502 if current_summary != &new_summary {
503 self.summary = Some(new_summary);
504 cx.emit(ThreadEvent::SummaryChanged);
505 }
506 }
507
508 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
509 self.latest_detailed_summary()
510 .unwrap_or_else(|| self.text().into())
511 }
512
513 fn latest_detailed_summary(&self) -> Option<SharedString> {
514 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
515 Some(text.clone())
516 } else {
517 None
518 }
519 }
520
521 pub fn message(&self, id: MessageId) -> Option<&Message> {
522 let index = self
523 .messages
524 .binary_search_by(|message| message.id.cmp(&id))
525 .ok()?;
526
527 self.messages.get(index)
528 }
529
530 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
531 self.messages.iter()
532 }
533
    /// Whether a completion is still streaming or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// The set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool use is still outstanding.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for message `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
563
    /// Rolls the project back to `checkpoint` and, on success, truncates the
    /// thread to the checkpoint's message.
    ///
    /// The restore is marked pending immediately so observers can reflect it;
    /// the marker is cleared on success or replaced with the error on failure.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        // Ask the git store to roll the working tree back to the checkpoint.
        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop the restored message and everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Any in-progress checkpoint is invalidated by the restore.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
599
    /// Finalizes the checkpoint staged at the start of the current turn.
    ///
    /// No-op while generation is still running or when no checkpoint is
    /// pending. Otherwise a fresh checkpoint is captured and compared with
    /// the pending one: the pending checkpoint is kept (made restorable)
    /// only when the project state actually changed, or when capture/compare
    /// failed — erring on the side of keeping it.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison is treated as "changed".
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
638
    /// Records `checkpoint` as restorable for its message and notifies
    /// observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Outcome of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
649
650 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
651 let Some(message_ix) = self
652 .messages
653 .iter()
654 .rposition(|message| message.id == message_id)
655 else {
656 return;
657 };
658 for deleted_message in self.messages.drain(message_ix..) {
659 self.checkpoints_by_message.remove(&deleted_message.id);
660 }
661 cx.notify();
662 }
663
    /// Iterates over the context items attached to message `id`; yields
    /// nothing when the message does not exist or carries no context.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
671
672 pub fn is_turn_end(&self, ix: usize) -> bool {
673 if self.messages.is_empty() {
674 return false;
675 }
676
677 if !self.is_generating() && ix == self.messages.len() - 1 {
678 return true;
679 }
680
681 let Some(message) = self.messages.get(ix) else {
682 return false;
683 };
684
685 if message.role != Role::Assistant {
686 return false;
687 }
688
689 self.messages
690 .get(ix + 1)
691 .and_then(|message| {
692 self.message(message.id)
693 .map(|next_message| next_message.role == Role::User)
694 })
695 .unwrap_or(false)
696 }
697
    /// Returns whether all of the tool uses have finished running.
    ///
    /// Tool uses that ended in an error count as finished.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses recorded for the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results recorded for the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a completed tool use, if any.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool use's result, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
730
    /// Appends a user message with its loaded context, starts tracking any
    /// referenced buffers in the action log, and (when provided) stages a
    /// git checkpoint to be finalized after the turn completes.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            cx,
        );

        // Staged here; finalize_pending_checkpoint decides whether to keep it.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Appends an assistant message composed of `segments`, with no attached
    /// context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
    }
772
773 pub fn insert_message(
774 &mut self,
775 role: Role,
776 segments: Vec<MessageSegment>,
777 loaded_context: LoadedContext,
778 cx: &mut Context<Self>,
779 ) -> MessageId {
780 let id = self.next_message_id.post_inc();
781 self.messages.push(Message {
782 id,
783 role,
784 segments,
785 loaded_context,
786 });
787 self.touch_updated_at();
788 cx.emit(ThreadEvent::MessageAdded(id));
789 id
790 }
791
792 pub fn edit_message(
793 &mut self,
794 id: MessageId,
795 new_role: Role,
796 new_segments: Vec<MessageSegment>,
797 cx: &mut Context<Self>,
798 ) -> bool {
799 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
800 return false;
801 };
802 message.role = new_role;
803 message.segments = new_segments;
804 self.touch_updated_at();
805 cx.emit(ThreadEvent::MessageEdited(id));
806 true
807 }
808
809 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
810 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
811 return false;
812 };
813 self.messages.remove(index);
814 self.touch_updated_at();
815 cx.emit(ThreadEvent::MessageDeleted(id));
816 true
817 }
818
819 /// Returns the representation of this [`Thread`] in a textual form.
820 ///
821 /// This is the representation we use when attaching a thread as context to another thread.
822 pub fn text(&self) -> String {
823 let mut text = String::new();
824
825 for message in &self.messages {
826 text.push_str(match message.role {
827 language_model::Role::User => "User:",
828 language_model::Role::Assistant => "Assistant:",
829 language_model::Role::System => "System:",
830 });
831 text.push('\n');
832
833 for segment in &message.segments {
834 match segment {
835 MessageSegment::Text(content) => text.push_str(content),
836 MessageSegment::Thinking { text: content, .. } => {
837 text.push_str(&format!("<think>{}</think>", content))
838 }
839 MessageSegment::RedactedThinking(_) => {}
840 }
841 }
842 text.push('\n');
843 }
844
845 text
846 }
847
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the shared snapshot first; the rest of serialization
            // reads thread state synchronously.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
911
    /// How many agentic turns may still run before sending stops.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    /// Sets the number of turns the thread may still run.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }

    /// Sends the thread to `model`, consuming one remaining turn.
    ///
    /// Does nothing once `remaining_turns` reaches zero. When the model
    /// supports tool calling, every enabled tool whose input schema can be
    /// produced for the model's input format is attached to the request.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let mut request = self.to_completion_request(cx);
        if model.supports_tools() {
            request.tools = self
                .tools()
                .read(cx)
                .enabled_tools(cx)
                .into_iter()
                .filter_map(|tool| {
                    // Skip tools that cannot be supported
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name: tool.name(),
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect();
        }

        self.stream_completion(request, model, window, cx);
    }
953
954 pub fn used_tools_since_last_user_message(&self) -> bool {
955 for message in self.messages.iter().rev() {
956 if self.tool_use.message_has_tool_results(message.id) {
957 return true;
958 } else if message.role == Role::User {
959 return false;
960 }
961 }
962
963 false
964 }
965
    /// Builds the [`LanguageModelRequest`] for this thread: the system prompt
    /// (when the project context is ready), every message with its attached
    /// context, segments, and tool uses/results, a cache marker on the final
    /// message, and a trailing notice about files that changed since the
    /// model last read them.
    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // System prompt: generated from the shared project context; failures
        // are surfaced to the UI rather than aborting the request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context goes in front of the message's own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results follow the assistant message that requested them.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1061
1062 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1063 let mut request = LanguageModelRequest {
1064 thread_id: None,
1065 prompt_id: None,
1066 messages: vec![],
1067 tools: Vec::new(),
1068 stop: Vec::new(),
1069 temperature: None,
1070 };
1071
1072 for message in &self.messages {
1073 let mut request_message = LanguageModelRequestMessage {
1074 role: message.role,
1075 content: Vec::new(),
1076 cache: false,
1077 };
1078
1079 for segment in &message.segments {
1080 match segment {
1081 MessageSegment::Text(text) => request_message
1082 .content
1083 .push(MessageContent::Text(text.clone())),
1084 MessageSegment::Thinking { .. } => {}
1085 MessageSegment::RedactedThinking(_) => {}
1086 }
1087 }
1088
1089 if request_message.content.is_empty() {
1090 continue;
1091 }
1092
1093 request.messages.push(request_message);
1094 }
1095
1096 request.messages.push(LanguageModelRequestMessage {
1097 role: Role::User,
1098 content: vec![MessageContent::Text(added_user_message)],
1099 cache: false,
1100 });
1101
1102 request
1103 }
1104
1105 fn attached_tracked_files_state(
1106 &self,
1107 messages: &mut Vec<LanguageModelRequestMessage>,
1108 cx: &App,
1109 ) {
1110 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1111
1112 let mut stale_message = String::new();
1113
1114 let action_log = self.action_log.read(cx);
1115
1116 for stale_file in action_log.stale_buffers(cx) {
1117 let Some(file) = stale_file.read(cx).file() else {
1118 continue;
1119 };
1120
1121 if stale_message.is_empty() {
1122 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1123 }
1124
1125 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1126 }
1127
1128 let mut content = Vec::with_capacity(2);
1129
1130 if !stale_message.is_empty() {
1131 content.push(stale_message.into());
1132 }
1133
1134 if !content.is_empty() {
1135 let context_message = LanguageModelRequestMessage {
1136 role: Role::User,
1137 content,
1138 cache: false,
1139 };
1140
1141 messages.push(context_message);
1142 }
1143 }
1144
    /// Streams a completion for `request` from `model`, applying each event to
    /// the thread as it arrives (text/thinking chunks, tool uses, usage
    /// updates), then reacts to the final stop reason or error. The spawned
    /// task is registered in `pending_completions` so it can be canceled.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Capture the request (and later, every response event) only when a
        // request callback is installed, to avoid cloning otherwise.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot cumulative usage before streaming so the telemetry at the
            // end can report this completion's delta.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                // Set once the stream starts (explicitly or implicitly) an
                // assistant message; tool uses attach to this message.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            // Malformed tool-input JSON is recoverable: record the
                            // failure as the tool's output and continue streaming.
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            // Any other completion error aborts the stream.
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per-request, so add
                                // only the delta since the previous update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses must hang off an assistant message;
                                // create an empty one if none was started.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Emit partial input only while it is still streaming.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so other tasks can make progress.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            // Stream is done (success or failure): finalize checkpoints, kick off
            // pending tools or surface the error, then report telemetry.
            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known error types to specific UI notifications;
                            // anything else becomes a generic error message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report this completion's token usage as a delta from the
                    // snapshot taken before streaming began.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1443
1444 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1445 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1446 return;
1447 };
1448
1449 if !model.provider.is_authenticated(cx) {
1450 return;
1451 }
1452
1453 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1454 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1455 If the conversation is about a specific subject, include it in the title. \
1456 Be descriptive. DO NOT speak in the first person.";
1457
1458 let request = self.to_summarize_request(added_user_message.into());
1459
1460 self.pending_summary = cx.spawn(async move |this, cx| {
1461 async move {
1462 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1463 let (mut messages, usage) = stream.await?;
1464
1465 if let Some(usage) = usage {
1466 this.update(cx, |_thread, cx| {
1467 cx.emit(ThreadEvent::UsageUpdated(usage));
1468 })
1469 .ok();
1470 }
1471
1472 let mut new_summary = String::new();
1473 while let Some(message) = messages.stream.next().await {
1474 let text = message?;
1475 let mut lines = text.lines();
1476 new_summary.extend(lines.next());
1477
1478 // Stop if the LLM generated multiple lines.
1479 if lines.next().is_some() {
1480 break;
1481 }
1482 }
1483
1484 this.update(cx, |this, cx| {
1485 if !new_summary.is_empty() {
1486 this.summary = Some(new_summary.into());
1487 }
1488
1489 cx.emit(ThreadEvent::SummaryGenerated);
1490 })?;
1491
1492 anyhow::Ok(())
1493 }
1494 .log_err()
1495 .await
1496 });
1497 }
1498
1499 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1500 let last_message_id = self.messages.last().map(|message| message.id)?;
1501
1502 match &self.detailed_summary_state {
1503 DetailedSummaryState::Generating { message_id, .. }
1504 | DetailedSummaryState::Generated { message_id, .. }
1505 if *message_id == last_message_id =>
1506 {
1507 // Already up-to-date
1508 return None;
1509 }
1510 _ => {}
1511 }
1512
1513 let ConfiguredModel { model, provider } =
1514 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1515
1516 if !provider.is_authenticated(cx) {
1517 return None;
1518 }
1519
1520 let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1521 1. A brief overview of what was discussed\n\
1522 2. Key facts or information discovered\n\
1523 3. Outcomes or conclusions reached\n\
1524 4. Any action items or next steps if any\n\
1525 Format it in Markdown with headings and bullet points.";
1526
1527 let request = self.to_summarize_request(added_user_message.into());
1528
1529 let task = cx.spawn(async move |thread, cx| {
1530 let stream = model.stream_completion_text(request, &cx);
1531 let Some(mut messages) = stream.await.log_err() else {
1532 thread
1533 .update(cx, |this, _cx| {
1534 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1535 })
1536 .log_err();
1537
1538 return;
1539 };
1540
1541 let mut new_detailed_summary = String::new();
1542
1543 while let Some(chunk) = messages.stream.next().await {
1544 if let Some(chunk) = chunk.log_err() {
1545 new_detailed_summary.push_str(&chunk);
1546 }
1547 }
1548
1549 thread
1550 .update(cx, |this, _cx| {
1551 this.detailed_summary_state = DetailedSummaryState::Generated {
1552 text: new_detailed_summary.into(),
1553 message_id: last_message_id,
1554 };
1555 })
1556 .log_err();
1557 });
1558
1559 self.detailed_summary_state = DetailedSummaryState::Generating {
1560 message_id: last_message_id,
1561 };
1562
1563 Some(task)
1564 }
1565
1566 pub fn is_generating_detailed_summary(&self) -> bool {
1567 matches!(
1568 self.detailed_summary_state,
1569 DetailedSummaryState::Generating { .. }
1570 )
1571 }
1572
1573 pub fn use_pending_tools(
1574 &mut self,
1575 window: Option<AnyWindowHandle>,
1576 cx: &mut Context<Self>,
1577 ) -> Vec<PendingToolUse> {
1578 self.auto_capture_telemetry(cx);
1579 let request = self.to_completion_request(cx);
1580 let messages = Arc::new(request.messages);
1581 let pending_tool_uses = self
1582 .tool_use
1583 .pending_tool_uses()
1584 .into_iter()
1585 .filter(|tool_use| tool_use.status.is_idle())
1586 .cloned()
1587 .collect::<Vec<_>>();
1588
1589 for tool_use in pending_tool_uses.iter() {
1590 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1591 if tool.needs_confirmation(&tool_use.input, cx)
1592 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1593 {
1594 self.tool_use.confirm_tool_use(
1595 tool_use.id.clone(),
1596 tool_use.ui_text.clone(),
1597 tool_use.input.clone(),
1598 messages.clone(),
1599 tool,
1600 );
1601 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1602 } else {
1603 self.run_tool(
1604 tool_use.id.clone(),
1605 tool_use.ui_text.clone(),
1606 tool_use.input.clone(),
1607 &messages,
1608 tool,
1609 window,
1610 cx,
1611 );
1612 }
1613 }
1614 }
1615
1616 pending_tool_uses
1617 }
1618
1619 pub fn receive_invalid_tool_json(
1620 &mut self,
1621 tool_use_id: LanguageModelToolUseId,
1622 tool_name: Arc<str>,
1623 invalid_json: Arc<str>,
1624 error: String,
1625 window: Option<AnyWindowHandle>,
1626 cx: &mut Context<Thread>,
1627 ) {
1628 log::error!("The model returned invalid input JSON: {invalid_json}");
1629
1630 let pending_tool_use = self.tool_use.insert_tool_output(
1631 tool_use_id.clone(),
1632 tool_name,
1633 Err(anyhow!("Error parsing input JSON: {error}")),
1634 cx,
1635 );
1636 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1637 pending_tool_use.ui_text.clone()
1638 } else {
1639 log::error!(
1640 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1641 );
1642 format!("Unknown tool {}", tool_use_id).into()
1643 };
1644
1645 cx.emit(ThreadEvent::InvalidToolInput {
1646 tool_use_id: tool_use_id.clone(),
1647 ui_text,
1648 invalid_input_json: invalid_json,
1649 });
1650
1651 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1652 }
1653
1654 pub fn run_tool(
1655 &mut self,
1656 tool_use_id: LanguageModelToolUseId,
1657 ui_text: impl Into<SharedString>,
1658 input: serde_json::Value,
1659 messages: &[LanguageModelRequestMessage],
1660 tool: Arc<dyn Tool>,
1661 window: Option<AnyWindowHandle>,
1662 cx: &mut Context<Thread>,
1663 ) {
1664 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1665 self.tool_use
1666 .run_pending_tool(tool_use_id, ui_text.into(), task);
1667 }
1668
1669 fn spawn_tool_use(
1670 &mut self,
1671 tool_use_id: LanguageModelToolUseId,
1672 messages: &[LanguageModelRequestMessage],
1673 input: serde_json::Value,
1674 tool: Arc<dyn Tool>,
1675 window: Option<AnyWindowHandle>,
1676 cx: &mut Context<Thread>,
1677 ) -> Task<()> {
1678 let tool_name: Arc<str> = tool.name().into();
1679
1680 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1681 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1682 } else {
1683 tool.run(
1684 input,
1685 messages,
1686 self.project.clone(),
1687 self.action_log.clone(),
1688 window,
1689 cx,
1690 )
1691 };
1692
1693 // Store the card separately if it exists
1694 if let Some(card) = tool_result.card.clone() {
1695 self.tool_use
1696 .insert_tool_result_card(tool_use_id.clone(), card);
1697 }
1698
1699 cx.spawn({
1700 async move |thread: WeakEntity<Thread>, cx| {
1701 let output = tool_result.output.await;
1702
1703 thread
1704 .update(cx, |thread, cx| {
1705 let pending_tool_use = thread.tool_use.insert_tool_output(
1706 tool_use_id.clone(),
1707 tool_name,
1708 output,
1709 cx,
1710 );
1711 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1712 })
1713 .ok();
1714 }
1715 })
1716 }
1717
1718 fn tool_finished(
1719 &mut self,
1720 tool_use_id: LanguageModelToolUseId,
1721 pending_tool_use: Option<PendingToolUse>,
1722 canceled: bool,
1723 window: Option<AnyWindowHandle>,
1724 cx: &mut Context<Self>,
1725 ) {
1726 if self.all_tools_finished() {
1727 let model_registry = LanguageModelRegistry::read_global(cx);
1728 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1729 if !canceled {
1730 self.send_to_model(model, window, cx);
1731 }
1732 self.auto_capture_telemetry(cx);
1733 }
1734 }
1735
1736 cx.emit(ThreadEvent::ToolFinished {
1737 tool_use_id,
1738 pending_tool_use,
1739 });
1740 }
1741
1742 /// Cancels the last pending completion, if there are any pending.
1743 ///
1744 /// Returns whether a completion was canceled.
1745 pub fn cancel_last_completion(
1746 &mut self,
1747 window: Option<AnyWindowHandle>,
1748 cx: &mut Context<Self>,
1749 ) -> bool {
1750 let mut canceled = self.pending_completions.pop().is_some();
1751
1752 for pending_tool_use in self.tool_use.cancel_pending() {
1753 canceled = true;
1754 self.tool_finished(
1755 pending_tool_use.id.clone(),
1756 Some(pending_tool_use),
1757 true,
1758 window,
1759 cx,
1760 );
1761 }
1762
1763 self.finalize_pending_checkpoint(cx);
1764 canceled
1765 }
1766
    /// Returns the thread-level feedback rating, if one has been recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1770
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1774
1775 pub fn report_message_feedback(
1776 &mut self,
1777 message_id: MessageId,
1778 feedback: ThreadFeedback,
1779 cx: &mut Context<Self>,
1780 ) -> Task<Result<()>> {
1781 if self.message_feedback.get(&message_id) == Some(&feedback) {
1782 return Task::ready(Ok(()));
1783 }
1784
1785 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1786 let serialized_thread = self.serialize(cx);
1787 let thread_id = self.id().clone();
1788 let client = self.project.read(cx).client();
1789
1790 let enabled_tool_names: Vec<String> = self
1791 .tools()
1792 .read(cx)
1793 .enabled_tools(cx)
1794 .iter()
1795 .map(|tool| tool.name().to_string())
1796 .collect();
1797
1798 self.message_feedback.insert(message_id, feedback);
1799
1800 cx.notify();
1801
1802 let message_content = self
1803 .message(message_id)
1804 .map(|msg| msg.to_string())
1805 .unwrap_or_default();
1806
1807 cx.background_spawn(async move {
1808 let final_project_snapshot = final_project_snapshot.await;
1809 let serialized_thread = serialized_thread.await?;
1810 let thread_data =
1811 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1812
1813 let rating = match feedback {
1814 ThreadFeedback::Positive => "positive",
1815 ThreadFeedback::Negative => "negative",
1816 };
1817 telemetry::event!(
1818 "Assistant Thread Rated",
1819 rating,
1820 thread_id,
1821 enabled_tool_names,
1822 message_id = message_id.0,
1823 message_content,
1824 thread_data,
1825 final_project_snapshot
1826 );
1827 client.telemetry().flush_events().await;
1828
1829 Ok(())
1830 })
1831 }
1832
1833 pub fn report_feedback(
1834 &mut self,
1835 feedback: ThreadFeedback,
1836 cx: &mut Context<Self>,
1837 ) -> Task<Result<()>> {
1838 let last_assistant_message_id = self
1839 .messages
1840 .iter()
1841 .rev()
1842 .find(|msg| msg.role == Role::Assistant)
1843 .map(|msg| msg.id);
1844
1845 if let Some(message_id) = last_assistant_message_id {
1846 self.report_message_feedback(message_id, feedback, cx)
1847 } else {
1848 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1849 let serialized_thread = self.serialize(cx);
1850 let thread_id = self.id().clone();
1851 let client = self.project.read(cx).client();
1852 self.feedback = Some(feedback);
1853 cx.notify();
1854
1855 cx.background_spawn(async move {
1856 let final_project_snapshot = final_project_snapshot.await;
1857 let serialized_thread = serialized_thread.await?;
1858 let thread_data = serde_json::to_value(serialized_thread)
1859 .unwrap_or_else(|_| serde_json::Value::Null);
1860
1861 let rating = match feedback {
1862 ThreadFeedback::Positive => "positive",
1863 ThreadFeedback::Negative => "negative",
1864 };
1865 telemetry::event!(
1866 "Assistant Thread Rated",
1867 rating,
1868 thread_id,
1869 thread_data,
1870 final_project_snapshot
1871 );
1872 client.telemetry().flush_events().await;
1873
1874 Ok(())
1875 })
1876 }
1877 }
1878
1879 /// Create a snapshot of the current project state including git information and unsaved buffers.
1880 fn project_snapshot(
1881 project: Entity<Project>,
1882 cx: &mut Context<Self>,
1883 ) -> Task<Arc<ProjectSnapshot>> {
1884 let git_store = project.read(cx).git_store().clone();
1885 let worktree_snapshots: Vec<_> = project
1886 .read(cx)
1887 .visible_worktrees(cx)
1888 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1889 .collect();
1890
1891 cx.spawn(async move |_, cx| {
1892 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1893
1894 let mut unsaved_buffers = Vec::new();
1895 cx.update(|app_cx| {
1896 let buffer_store = project.read(app_cx).buffer_store();
1897 for buffer_handle in buffer_store.read(app_cx).buffers() {
1898 let buffer = buffer_handle.read(app_cx);
1899 if buffer.is_dirty() {
1900 if let Some(file) = buffer.file() {
1901 let path = file.path().to_string_lossy().to_string();
1902 unsaved_buffers.push(path);
1903 }
1904 }
1905 }
1906 })
1907 .ok();
1908
1909 Arc::new(ProjectSnapshot {
1910 worktree_snapshots,
1911 unsaved_buffer_paths: unsaved_buffers,
1912 timestamp: Utc::now(),
1913 })
1914 })
1915 }
1916
1917 fn worktree_snapshot(
1918 worktree: Entity<project::Worktree>,
1919 git_store: Entity<GitStore>,
1920 cx: &App,
1921 ) -> Task<WorktreeSnapshot> {
1922 cx.spawn(async move |cx| {
1923 // Get worktree path and snapshot
1924 let worktree_info = cx.update(|app_cx| {
1925 let worktree = worktree.read(app_cx);
1926 let path = worktree.abs_path().to_string_lossy().to_string();
1927 let snapshot = worktree.snapshot();
1928 (path, snapshot)
1929 });
1930
1931 let Ok((worktree_path, _snapshot)) = worktree_info else {
1932 return WorktreeSnapshot {
1933 worktree_path: String::new(),
1934 git_state: None,
1935 };
1936 };
1937
1938 let git_state = git_store
1939 .update(cx, |git_store, cx| {
1940 git_store
1941 .repositories()
1942 .values()
1943 .find(|repo| {
1944 repo.read(cx)
1945 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1946 .is_some()
1947 })
1948 .cloned()
1949 })
1950 .ok()
1951 .flatten()
1952 .map(|repo| {
1953 repo.update(cx, |repo, _| {
1954 let current_branch =
1955 repo.branch.as_ref().map(|branch| branch.name.to_string());
1956 repo.send_job(None, |state, _| async move {
1957 let RepositoryState::Local { backend, .. } = state else {
1958 return GitState {
1959 remote_url: None,
1960 head_sha: None,
1961 current_branch,
1962 diff: None,
1963 };
1964 };
1965
1966 let remote_url = backend.remote_url("origin");
1967 let head_sha = backend.head_sha().await;
1968 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1969
1970 GitState {
1971 remote_url,
1972 head_sha,
1973 current_branch,
1974 diff,
1975 }
1976 })
1977 })
1978 });
1979
1980 let git_state = match git_state {
1981 Some(git_state) => match git_state.ok() {
1982 Some(git_state) => git_state.await.ok(),
1983 None => None,
1984 },
1985 None => None,
1986 };
1987
1988 WorktreeSnapshot {
1989 worktree_path,
1990 git_state,
1991 }
1992 })
1993 }
1994
1995 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1996 let mut markdown = Vec::new();
1997
1998 if let Some(summary) = self.summary() {
1999 writeln!(markdown, "# {summary}\n")?;
2000 };
2001
2002 for message in self.messages() {
2003 writeln!(
2004 markdown,
2005 "## {role}\n",
2006 role = match message.role {
2007 Role::User => "User",
2008 Role::Assistant => "Assistant",
2009 Role::System => "System",
2010 }
2011 )?;
2012
2013 if !message.loaded_context.text.is_empty() {
2014 writeln!(markdown, "{}", message.loaded_context.text)?;
2015 }
2016
2017 if !message.loaded_context.images.is_empty() {
2018 writeln!(
2019 markdown,
2020 "\n{} images attached as context.\n",
2021 message.loaded_context.images.len()
2022 )?;
2023 }
2024
2025 for segment in &message.segments {
2026 match segment {
2027 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2028 MessageSegment::Thinking { text, .. } => {
2029 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2030 }
2031 MessageSegment::RedactedThinking(_) => {}
2032 }
2033 }
2034
2035 for tool_use in self.tool_uses_for_message(message.id, cx) {
2036 writeln!(
2037 markdown,
2038 "**Use Tool: {} ({})**",
2039 tool_use.name, tool_use.id
2040 )?;
2041 writeln!(markdown, "```json")?;
2042 writeln!(
2043 markdown,
2044 "{}",
2045 serde_json::to_string_pretty(&tool_use.input)?
2046 )?;
2047 writeln!(markdown, "```")?;
2048 }
2049
2050 for tool_result in self.tool_results_for_message(message.id) {
2051 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2052 if tool_result.is_error {
2053 write!(markdown, " (Error)")?;
2054 }
2055
2056 writeln!(markdown, "**\n")?;
2057 writeln!(markdown, "{}", tool_result.content)?;
2058 }
2059 }
2060
2061 Ok(String::from_utf8_lossy(&markdown).to_string())
2062 }
2063
    /// Marks agent edits within `buffer_range` of `buffer` as accepted,
    /// delegating to the thread's [`ActionLog`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2074
    /// Marks every outstanding agent edit as accepted via the [`ActionLog`].
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2079
    /// Reverts agent edits within the given ranges of `buffer`, delegating to
    /// the thread's [`ActionLog`]. Returns the log's async revert task.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2090
    /// The action log tracking agent edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2094
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2098
2099 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2100 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2101 return;
2102 }
2103
2104 let now = Instant::now();
2105 if let Some(last) = self.last_auto_capture_at {
2106 if now.duration_since(last).as_secs() < 10 {
2107 return;
2108 }
2109 }
2110
2111 self.last_auto_capture_at = Some(now);
2112
2113 let thread_id = self.id().clone();
2114 let github_login = self
2115 .project
2116 .read(cx)
2117 .user_store()
2118 .read(cx)
2119 .current_user()
2120 .map(|user| user.github_login.clone());
2121 let client = self.project.read(cx).client().clone();
2122 let serialize_task = self.serialize(cx);
2123
2124 cx.background_executor()
2125 .spawn(async move {
2126 if let Ok(serialized_thread) = serialize_task.await {
2127 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2128 telemetry::event!(
2129 "Agent Thread Auto-Captured",
2130 thread_id = thread_id.to_string(),
2131 thread_data = thread_data,
2132 auto_capture_reason = "tracked_user",
2133 github_login = github_login
2134 );
2135
2136 client.telemetry().flush_events().await;
2137 }
2138 }
2139 })
2140 .detach();
2141 }
2142
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2146
2147 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2148 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2149 return TotalTokenUsage::default();
2150 };
2151
2152 let max = model.model.max_token_count();
2153
2154 let index = self
2155 .messages
2156 .iter()
2157 .position(|msg| msg.id == message_id)
2158 .unwrap_or(0);
2159
2160 if index == 0 {
2161 return TotalTokenUsage { total: 0, max };
2162 }
2163
2164 let token_usage = &self
2165 .request_token_usage
2166 .get(index - 1)
2167 .cloned()
2168 .unwrap_or_default();
2169
2170 TotalTokenUsage {
2171 total: token_usage.total_tokens() as usize,
2172 max,
2173 }
2174 }
2175
2176 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2177 let model_registry = LanguageModelRegistry::read_global(cx);
2178 let Some(model) = model_registry.default_model() else {
2179 return TotalTokenUsage::default();
2180 };
2181
2182 let max = model.model.max_token_count();
2183
2184 if let Some(exceeded_error) = &self.exceeded_window_error {
2185 if model.model.id() == exceeded_error.model_id {
2186 return TotalTokenUsage {
2187 total: exceeded_error.token_count,
2188 max,
2189 };
2190 }
2191 }
2192
2193 let total = self
2194 .token_usage_at_last_message()
2195 .unwrap_or_default()
2196 .total_tokens() as usize;
2197
2198 TotalTokenUsage { total, max }
2199 }
2200
2201 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2202 self.request_token_usage
2203 .get(self.messages.len().saturating_sub(1))
2204 .or_else(|| self.request_token_usage.last())
2205 .cloned()
2206 }
2207
2208 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2209 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2210 self.request_token_usage
2211 .resize(self.messages.len(), placeholder);
2212
2213 if let Some(last) = self.request_token_usage.last_mut() {
2214 *last = token_usage;
2215 }
2216 }
2217
2218 pub fn deny_tool_use(
2219 &mut self,
2220 tool_use_id: LanguageModelToolUseId,
2221 tool_name: Arc<str>,
2222 window: Option<AnyWindowHandle>,
2223 cx: &mut Context<Self>,
2224 ) {
2225 let err = Err(anyhow::anyhow!(
2226 "Permission to run tool action denied by user"
2227 ));
2228
2229 self.tool_use
2230 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2231 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2232 }
2233}
2234
/// User-facing errors surfaced while interacting with the language model,
/// emitted via `ThreadEvent::ShowError`.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider refused the request pending payment.
    #[error("Payment required")]
    PaymentRequired,
    /// The account's monthly spend cap has been hit.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model-request quota has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, rendered with a header and detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2249
/// Events emitted by a [`Thread`] as completions stream in, tools run, and
/// summaries are generated.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    /// Provider-reported request usage was received for a completion.
    UsageUpdated(RequestUsage),
    /// A streaming event was applied to the thread.
    StreamedCompletion,
    /// A raw text chunk arrived from the model (before message routing).
    ReceivedTextChunk,
    /// Text was appended to an existing assistant message.
    StreamedAssistantText(MessageId, String),
    /// Thinking text was appended to an existing assistant message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use arrived whose input is still streaming in.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stream ended with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// A short thread title finished generating.
    SummaryGenerated,
    SummaryChanged,
    /// A completion stopped with `StopReason::ToolUse`; these tools were dispatched.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A single tool use completed (successfully, with an error, or canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it may run.
    ToolConfirmationNeeded,
}
2286
// Allow `Thread` to emit `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2288
/// An in-flight completion request tracked by the thread.
struct PendingCompletion {
    // Monotonic id (from `completion_count`) used to retire this entry when
    // its stream finishes.
    id: usize,
    // Holds the streaming task for the lifetime of this entry; popping the
    // entry drops the task (see `cancel_last_completion`).
    _task: Task<()>,
}
2293
#[cfg(test)]
mod tests {
    //! Tests covering how user messages and attached file context are
    //! assembled into `LanguageModelRequest`s: context formatting, context
    //! deduplication across messages, and stale-buffer notifications.

    use super::*;
    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// A user message with an attached file should carry the formatted
    /// `<context>` block, both on the message object and in the request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", loaded_context, None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // messages[0] is the system/initial message; the user message follows.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    /// Context already attached to an earlier message must not be repeated
    /// in later messages: each message only includes contexts that are new.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    /// Messages inserted with an empty `ContextLoadResult` should carry no
    /// context text at all, on the message or in the request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    /// When a context buffer is edited after being attached, the next request
    /// must end with a "files changed since last read" notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    /// Registers every settings/global the thread machinery reads during
    /// these tests; must be called before any entities are created.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Builds the standard fixture for these tests: a test workspace window,
    /// a `ThreadStore` (with an empty tool working set and no prompt-store
    /// id), a fresh `Thread`, and a `ContextStore` bound to the project.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    /// Opens `path` in the project and attaches its buffer to the context
    /// store, returning the buffer so tests can edit it afterwards.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
}