1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
18};
19use language_model::{
20 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
21 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
22 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
23 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
24 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
25 TokenUsage,
26};
27use project::Project;
28use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
29use prompt_store::PromptBuilder;
30use proto::Plan;
31use schemars::JsonSchema;
32use serde::{Deserialize, Serialize};
33use settings::Settings;
34use thiserror::Error;
35use util::{ResultExt as _, TryFutureExt as _, post_inc};
36use uuid::Uuid;
37
38use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
39use crate::thread_store::{
40 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
41 SerializedToolUse, SharedProjectContext,
42};
43use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
44
45#[derive(
46 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
47)]
48pub struct ThreadId(Arc<str>);
49
50impl ThreadId {
51 pub fn new() -> Self {
52 Self(Uuid::new_v4().to_string().into())
53 }
54}
55
56impl std::fmt::Display for ThreadId {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 write!(f, "{}", self.0)
59 }
60}
61
62impl From<&str> for ThreadId {
63 fn from(value: &str) -> Self {
64 Self(value.into())
65 }
66}
67
68/// The ID of the user prompt that initiated a request.
69///
70/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
71#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
72pub struct PromptId(Arc<str>);
73
74impl PromptId {
75 pub fn new() -> Self {
76 Self(Uuid::new_v4().to_string().into())
77 }
78}
79
80impl std::fmt::Display for PromptId {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 write!(f, "{}", self.0)
83 }
84}
85
86#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
87pub struct MessageId(pub(crate) usize);
88
89impl MessageId {
90 fn post_inc(&mut self) -> Self {
91 Self(post_inc(&mut self.0))
92 }
93}
94
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    // Ordered content segments; consecutive pushes of the same kind are
    // merged into the last segment (see `push_text` / `push_thinking`).
    pub segments: Vec<MessageSegment>,
    // Resolved context (structured entries, flattened text, and images)
    // attached to this message.
    pub loaded_context: LoadedContext,
}
103
104impl Message {
105 /// Returns whether the message contains any meaningful text that should be displayed
106 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
107 pub fn should_display_content(&self) -> bool {
108 self.segments.iter().all(|segment| segment.should_display())
109 }
110
111 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
112 if let Some(MessageSegment::Thinking {
113 text: segment,
114 signature: current_signature,
115 }) = self.segments.last_mut()
116 {
117 if let Some(signature) = signature {
118 *current_signature = Some(signature);
119 }
120 segment.push_str(text);
121 } else {
122 self.segments.push(MessageSegment::Thinking {
123 text: text.to_string(),
124 signature,
125 });
126 }
127 }
128
129 pub fn push_text(&mut self, text: &str) {
130 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
131 segment.push_str(text);
132 } else {
133 self.segments.push(MessageSegment::Text(text.to_string()));
134 }
135 }
136
137 pub fn to_string(&self) -> String {
138 let mut result = String::new();
139
140 if !self.loaded_context.text.is_empty() {
141 result.push_str(&self.loaded_context.text);
142 }
143
144 for segment in &self.segments {
145 match segment {
146 MessageSegment::Text(text) => result.push_str(text),
147 MessageSegment::Thinking { text, .. } => {
148 result.push_str("<think>\n");
149 result.push_str(text);
150 result.push_str("\n</think>");
151 }
152 MessageSegment::RedactedThinking(_) => {}
153 }
154 }
155
156 result
157 }
158}
159
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has content worth rendering.
    ///
    /// Fix: the previous implementation returned `text.is_empty()`, i.e. it
    /// reported *empty* segments as displayable — the inverse of what
    /// [`Message::should_display_content`]'s documentation describes. Text and
    /// thinking segments are displayable only when non-empty; redacted
    /// thinking is never displayed.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
179
/// Snapshot of the project's worktrees, captured when a thread is created
/// and stored alongside the serialized thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Paths of buffers that had unsaved changes at snapshot time —
    // populated by `Thread::project_snapshot` (not visible here); TODO confirm.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, with its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // Presumably `None` when the worktree has no git repository — TODO confirm.
    pub git_state: Option<GitState>,
}

/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// Pairs a git checkpoint with the user message it was taken for
/// (see `insert_user_message` / `insert_checkpoint`).
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Thumbs-up / thumbs-down feedback on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
212
213pub enum LastRestoreCheckpoint {
214 Pending {
215 message_id: MessageId,
216 },
217 Error {
218 message_id: MessageId,
219 error: String,
220 },
221}
222
223impl LastRestoreCheckpoint {
224 pub fn message_id(&self) -> MessageId {
225 match self {
226 LastRestoreCheckpoint::Pending { message_id } => *message_id,
227 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
228 }
229 }
230}
231
/// State machine for the lazily generated detailed summary of a thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been produced yet.
    #[default]
    NotGenerated,
    /// Generation is in flight. `message_id` presumably marks the message the
    /// generation was kicked off from — TODO confirm against the generator.
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary is available (see `latest_detailed_summary`).
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
244
/// Running token totals for a thread, measured against the model's limit.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens available to the thread.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, nearing the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be tuned via an env var.
        // Fall back to the default instead of panicking on a malformed value
        // (the previous `.parse().unwrap()` would crash the process).
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // NOTE: when `max` is 0 the first branch reports `Exceeded` before the
        // 0/0 (NaN) division below can ever be evaluated.
        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more consumed.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
285
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // `None` until a summary has been generated; `set_summary` refuses to
    // overwrite while it is `None`.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    // Ordered by ascending `MessageId`; `message()` binary-searches on this.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    // Monotonic counter identifying in-flight completions.
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken at the start of the current turn; kept or dropped by
    // `finalize_pending_checkpoint` once generation finishes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Optional hook observing each request plus its raw response events;
    // installed via `set_request_callback`.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Model turns still allowed; decremented by `send_to_model`.
    remaining_turns: u32,
}
319
/// Details of the most recent context-window overflow, stored in
/// `Thread::exceeded_window_error` and persisted with the thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
327
328impl Thread {
    /// Creates a new, empty thread with a freshly generated [`ThreadId`].
    ///
    /// Also spawns a background task that captures an initial
    /// [`ProjectSnapshot`] of the project for later serialization.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project state as of thread creation, off the hot path.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            // Unlimited by default; callers can restrict via `set_remaining_turns`.
            remaining_turns: u32::MAX,
        }
    }
372
    /// Rebuilds a thread from its serialized form.
    ///
    /// Message IDs continue after the highest persisted ID, and tool-use state
    /// is reconstructed from the serialized messages. Ephemeral state
    /// (checkpoints, feedback, callbacks, prompt ID) starts fresh.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the last persisted message.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // structured contexts and images are not persisted.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
447
    /// Installs a hook that observes every completion request along with the
    /// raw stream of response events (errors stringified).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps `updated_at` to the current time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if none has been generated yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Placeholder title used until a summary exists.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The summary, falling back to [`Self::DEFAULT_SUMMARY`].
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
489
490 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
491 let Some(current_summary) = &self.summary else {
492 // Don't allow setting summary until generated
493 return;
494 };
495
496 let mut new_summary = new_summary.into();
497
498 if new_summary.is_empty() {
499 new_summary = Self::DEFAULT_SUMMARY;
500 }
501
502 if current_summary != &new_summary {
503 self.summary = Some(new_summary);
504 cx.emit(ThreadEvent::SummaryChanged);
505 }
506 }
507
508 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
509 self.latest_detailed_summary()
510 .unwrap_or_else(|| self.text().into())
511 }
512
513 fn latest_detailed_summary(&self) -> Option<SharedString> {
514 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
515 Some(text.clone())
516 } else {
517 None
518 }
519 }
520
521 pub fn message(&self, id: MessageId) -> Option<&Message> {
522 let index = self
523 .messages
524 .binary_search_by(|message| message.id.cmp(&id))
525 .ok()?;
526
527 self.messages.get(index)
528 }
529
    /// All messages, in ascending ID order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The pending tool use with the given ID, if any.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses waiting on user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for message `id`, if one was kept.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
563
    /// Restores the project to `checkpoint`'s saved git state, truncating the
    /// thread back to the checkpoint's message on success.
    ///
    /// Progress and failure are surfaced via `last_restore_checkpoint` and
    /// `ThreadEvent::CheckpointChanged`.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so observers can reflect it.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failure around so it can be shown to the user.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop the messages past the checkpoint and clear the
                    // pending-restore marker.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
599
    /// Decides the fate of the pending checkpoint once generation finishes:
    /// keep it only if the project actually changed during the turn.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        // Still streaming or running tools — leave the checkpoint pending.
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // Compare the turn-start checkpoint against the current state;
                // if the comparison itself fails, conservatively treat the two
                // as different so the checkpoint is retained.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint — keep the pending one so the
            // user can still restore to the turn start.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
638
    /// Records `checkpoint` under its message ID and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The state of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
649
650 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
651 let Some(message_ix) = self
652 .messages
653 .iter()
654 .rposition(|message| message.id == message_id)
655 else {
656 return;
657 };
658 for deleted_message in self.messages.drain(message_ix..) {
659 self.checkpoints_by_message.remove(&deleted_message.id);
660 }
661 cx.notify();
662 }
663
664 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
665 self.messages
666 .iter()
667 .find(|message| message.id == id)
668 .into_iter()
669 .flat_map(|message| message.loaded_context.contexts.iter())
670 }
671
672 pub fn is_turn_end(&self, ix: usize) -> bool {
673 if self.messages.is_empty() {
674 return false;
675 }
676
677 if !self.is_generating() && ix == self.messages.len() - 1 {
678 return true;
679 }
680
681 let Some(message) = self.messages.get(ix) else {
682 return false;
683 };
684
685 if message.role != Role::Assistant {
686 return false;
687 }
688
689 self.messages
690 .get(ix + 1)
691 .and_then(|message| {
692 self.message(message.id)
693 .map(|next_message| next_message.role == Role::User)
694 })
695 .unwrap_or(false)
696 }
697
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// All tool uses issued by the message `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// All tool results associated with the assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result recorded for tool use `id`, if any.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of tool use `id`, if a result exists.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with tool use `id`, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
730
    /// Appends a user message (plus its loaded context) to the thread.
    ///
    /// When a `git_checkpoint` is supplied it becomes the pending checkpoint
    /// for this turn, to be finalized once generation completes.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Track buffers referenced by the context in the action log.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
764
    /// Appends an assistant message with the given segments and no context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
    }

    /// Appends a message with the next sequential ID and notifies observers.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // IDs are allocated monotonically, keeping `messages` sorted by ID.
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
791
792 pub fn edit_message(
793 &mut self,
794 id: MessageId,
795 new_role: Role,
796 new_segments: Vec<MessageSegment>,
797 cx: &mut Context<Self>,
798 ) -> bool {
799 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
800 return false;
801 };
802 message.role = new_role;
803 message.segments = new_segments;
804 self.touch_updated_at();
805 cx.emit(ThreadEvent::MessageEdited(id));
806 true
807 }
808
809 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
810 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
811 return false;
812 };
813 self.messages.remove(index);
814 self.touch_updated_at();
815 cx.emit(ThreadEvent::MessageDeleted(id));
816 true
817 }
818
819 /// Returns the representation of this [`Thread`] in a textual form.
820 ///
821 /// This is the representation we use when attaching a thread as context to another thread.
822 pub fn text(&self) -> String {
823 let mut text = String::new();
824
825 for message in &self.messages {
826 text.push_str(match message.role {
827 language_model::Role::User => "User:",
828 language_model::Role::Assistant => "Assistant:",
829 language_model::Role::System => "System:",
830 });
831 text.push('\n');
832
833 for segment in &message.segments {
834 match segment {
835 MessageSegment::Text(content) => text.push_str(content),
836 MessageSegment::Thinking { text: content, .. } => {
837 text.push_str(&format!("<think>{}</think>", content))
838 }
839 MessageSegment::RedactedThinking(_) => {}
840 }
841 }
842 text.push('\n');
843 }
844
845 text
846 }
847
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the initial project snapshot (captured at thread creation)
    /// before reading the rest of the state from the entity.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results are persisted per message so
                        // `ToolUseState` can be rebuilt on deserialize.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
911
    /// Model turns this thread may still take; `u32::MAX` means unlimited.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    /// Sets the turn budget consumed by `send_to_model`.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
919
920 pub fn send_to_model(
921 &mut self,
922 model: Arc<dyn LanguageModel>,
923 window: Option<AnyWindowHandle>,
924 cx: &mut Context<Self>,
925 ) {
926 if self.remaining_turns == 0 {
927 return;
928 }
929
930 self.remaining_turns -= 1;
931
932 let mut request = self.to_completion_request(cx);
933 if model.supports_tools() {
934 request.tools = {
935 let mut tools = Vec::new();
936 tools.extend(
937 self.tools()
938 .read(cx)
939 .enabled_tools(cx)
940 .into_iter()
941 .filter_map(|tool| {
942 // Skip tools that cannot be supported
943 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
944 Some(LanguageModelRequestTool {
945 name: tool.name(),
946 description: tool.description(),
947 input_schema,
948 })
949 }),
950 );
951
952 tools
953 };
954 }
955
956 self.stream_completion(request, model, window, cx);
957 }
958
959 pub fn used_tools_since_last_user_message(&self) -> bool {
960 for message in self.messages.iter().rev() {
961 if self.tool_use.message_has_tool_results(message.id) {
962 return true;
963 } else if message.role == Role::User {
964 return false;
965 }
966 }
967
968 false
969 }
970
    /// Builds the completion request for the current conversation: system
    /// prompt, every message (context, segments, tool uses, tool results),
    /// a cache anchor on the last message, and a trailing note about files
    /// that changed since the model last read them.
    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // The system prompt is rendered from the shared project context; if it
        // is missing or fails to render, emit an error event but still send
        // the request without a system message.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context precedes the message's own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    // Empty text/thinking segments are dropped from the request.
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results follow as a separate message after the assistant
            // message that requested them.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // Mark the last message as a prompt-cache anchor.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1066
    /// Builds a request for summarizing this thread: the conversation's plain
    /// text segments followed by `added_user_message` as the final prompt.
    /// Thinking segments, tool traffic, and per-message context are omitted.
    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Only plain text contributes to the summary input.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => request_message
                        .content
                        .push(MessageContent::Text(text.clone())),
                    MessageSegment::Thinking { .. } => {}
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            // Skip messages that contributed no text at all.
            if request_message.content.is_empty() {
                continue;
            }

            request.messages.push(request_message);
        }

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text(added_user_message)],
            cache: false,
        });

        request
    }
1109
1110 fn attached_tracked_files_state(
1111 &self,
1112 messages: &mut Vec<LanguageModelRequestMessage>,
1113 cx: &App,
1114 ) {
1115 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1116
1117 let mut stale_message = String::new();
1118
1119 let action_log = self.action_log.read(cx);
1120
1121 for stale_file in action_log.stale_buffers(cx) {
1122 let Some(file) = stale_file.read(cx).file() else {
1123 continue;
1124 };
1125
1126 if stale_message.is_empty() {
1127 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1128 }
1129
1130 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1131 }
1132
1133 let mut content = Vec::with_capacity(2);
1134
1135 if !stale_message.is_empty() {
1136 content.push(stale_message.into());
1137 }
1138
1139 if !content.is_empty() {
1140 let context_message = LanguageModelRequestMessage {
1141 role: Role::User,
1142 content,
1143 cache: false,
1144 };
1145
1146 messages.push(context_message);
1147 }
1148 }
1149
    /// Streams a completion from `model` for `request`, applying each event to
    /// this thread as it arrives (text/thinking chunks, tool uses, usage
    /// updates), then handles the final stop reason or error.
    ///
    /// The spawned work is registered in `self.pending_completions` so it can
    /// be canceled via `cancel_last_completion`; dropping the task aborts it.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // When a request callback is installed, capture the request plus every
        // response event so the callback can observe the whole exchange.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot usage before streaming so we can report the delta to
            // telemetry at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            // Drives the event stream to completion and yields the stop reason.
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                // Set once an assistant message exists for this response; tool
                // uses are attached to it.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            // Malformed tool-input JSON is recoverable: record
                            // it as a failed tool use and keep streaming.
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Usage updates are cumulative per request, so
                                // only add the delta since the previous update.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses must hang off an assistant message;
                                // create an empty one if none exists yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only partially-streamed inputs are re-emitted
                                // to the UI below.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so other tasks get a chance to run.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        // Map known error types to specific UI errors; anything
                        // else becomes a generic message with the error chain.
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report only the tokens consumed by this completion.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1448
1449 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1450 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1451 return;
1452 };
1453
1454 if !model.provider.is_authenticated(cx) {
1455 return;
1456 }
1457
1458 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1459 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1460 If the conversation is about a specific subject, include it in the title. \
1461 Be descriptive. DO NOT speak in the first person.";
1462
1463 let request = self.to_summarize_request(added_user_message.into());
1464
1465 self.pending_summary = cx.spawn(async move |this, cx| {
1466 async move {
1467 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1468 let (mut messages, usage) = stream.await?;
1469
1470 if let Some(usage) = usage {
1471 this.update(cx, |_thread, cx| {
1472 cx.emit(ThreadEvent::UsageUpdated(usage));
1473 })
1474 .ok();
1475 }
1476
1477 let mut new_summary = String::new();
1478 while let Some(message) = messages.stream.next().await {
1479 let text = message?;
1480 let mut lines = text.lines();
1481 new_summary.extend(lines.next());
1482
1483 // Stop if the LLM generated multiple lines.
1484 if lines.next().is_some() {
1485 break;
1486 }
1487 }
1488
1489 this.update(cx, |this, cx| {
1490 if !new_summary.is_empty() {
1491 this.summary = Some(new_summary.into());
1492 }
1493
1494 cx.emit(ThreadEvent::SummaryGenerated);
1495 })?;
1496
1497 anyhow::Ok(())
1498 }
1499 .log_err()
1500 .await
1501 });
1502 }
1503
1504 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1505 let last_message_id = self.messages.last().map(|message| message.id)?;
1506
1507 match &self.detailed_summary_state {
1508 DetailedSummaryState::Generating { message_id, .. }
1509 | DetailedSummaryState::Generated { message_id, .. }
1510 if *message_id == last_message_id =>
1511 {
1512 // Already up-to-date
1513 return None;
1514 }
1515 _ => {}
1516 }
1517
1518 let ConfiguredModel { model, provider } =
1519 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1520
1521 if !provider.is_authenticated(cx) {
1522 return None;
1523 }
1524
1525 let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1526 1. A brief overview of what was discussed\n\
1527 2. Key facts or information discovered\n\
1528 3. Outcomes or conclusions reached\n\
1529 4. Any action items or next steps if any\n\
1530 Format it in Markdown with headings and bullet points.";
1531
1532 let request = self.to_summarize_request(added_user_message.into());
1533
1534 let task = cx.spawn(async move |thread, cx| {
1535 let stream = model.stream_completion_text(request, &cx);
1536 let Some(mut messages) = stream.await.log_err() else {
1537 thread
1538 .update(cx, |this, _cx| {
1539 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1540 })
1541 .log_err();
1542
1543 return;
1544 };
1545
1546 let mut new_detailed_summary = String::new();
1547
1548 while let Some(chunk) = messages.stream.next().await {
1549 if let Some(chunk) = chunk.log_err() {
1550 new_detailed_summary.push_str(&chunk);
1551 }
1552 }
1553
1554 thread
1555 .update(cx, |this, _cx| {
1556 this.detailed_summary_state = DetailedSummaryState::Generated {
1557 text: new_detailed_summary.into(),
1558 message_id: last_message_id,
1559 };
1560 })
1561 .log_err();
1562 });
1563
1564 self.detailed_summary_state = DetailedSummaryState::Generating {
1565 message_id: last_message_id,
1566 };
1567
1568 Some(task)
1569 }
1570
    /// Whether a detailed-summary request is currently in flight.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1577
1578 pub fn use_pending_tools(
1579 &mut self,
1580 window: Option<AnyWindowHandle>,
1581 cx: &mut Context<Self>,
1582 ) -> Vec<PendingToolUse> {
1583 self.auto_capture_telemetry(cx);
1584 let request = self.to_completion_request(cx);
1585 let messages = Arc::new(request.messages);
1586 let pending_tool_uses = self
1587 .tool_use
1588 .pending_tool_uses()
1589 .into_iter()
1590 .filter(|tool_use| tool_use.status.is_idle())
1591 .cloned()
1592 .collect::<Vec<_>>();
1593
1594 for tool_use in pending_tool_uses.iter() {
1595 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1596 if tool.needs_confirmation(&tool_use.input, cx)
1597 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1598 {
1599 self.tool_use.confirm_tool_use(
1600 tool_use.id.clone(),
1601 tool_use.ui_text.clone(),
1602 tool_use.input.clone(),
1603 messages.clone(),
1604 tool,
1605 );
1606 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1607 } else {
1608 self.run_tool(
1609 tool_use.id.clone(),
1610 tool_use.ui_text.clone(),
1611 tool_use.input.clone(),
1612 &messages,
1613 tool,
1614 window,
1615 cx,
1616 );
1617 }
1618 }
1619 }
1620
1621 pending_tool_uses
1622 }
1623
1624 pub fn receive_invalid_tool_json(
1625 &mut self,
1626 tool_use_id: LanguageModelToolUseId,
1627 tool_name: Arc<str>,
1628 invalid_json: Arc<str>,
1629 error: String,
1630 window: Option<AnyWindowHandle>,
1631 cx: &mut Context<Thread>,
1632 ) {
1633 log::error!("The model returned invalid input JSON: {invalid_json}");
1634
1635 let pending_tool_use = self.tool_use.insert_tool_output(
1636 tool_use_id.clone(),
1637 tool_name,
1638 Err(anyhow!("Error parsing input JSON: {error}")),
1639 cx,
1640 );
1641 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1642 pending_tool_use.ui_text.clone()
1643 } else {
1644 log::error!(
1645 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1646 );
1647 format!("Unknown tool {}", tool_use_id).into()
1648 };
1649
1650 cx.emit(ThreadEvent::InvalidToolInput {
1651 tool_use_id: tool_use_id.clone(),
1652 ui_text,
1653 invalid_input_json: invalid_json,
1654 });
1655
1656 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1657 }
1658
    /// Spawns `tool` with `input` and registers it as running under `ui_text`
    /// so its progress can be displayed.
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        messages: &[LanguageModelRequestMessage],
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
1673
1674 fn spawn_tool_use(
1675 &mut self,
1676 tool_use_id: LanguageModelToolUseId,
1677 messages: &[LanguageModelRequestMessage],
1678 input: serde_json::Value,
1679 tool: Arc<dyn Tool>,
1680 window: Option<AnyWindowHandle>,
1681 cx: &mut Context<Thread>,
1682 ) -> Task<()> {
1683 let tool_name: Arc<str> = tool.name().into();
1684
1685 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1686 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1687 } else {
1688 tool.run(
1689 input,
1690 messages,
1691 self.project.clone(),
1692 self.action_log.clone(),
1693 window,
1694 cx,
1695 )
1696 };
1697
1698 // Store the card separately if it exists
1699 if let Some(card) = tool_result.card.clone() {
1700 self.tool_use
1701 .insert_tool_result_card(tool_use_id.clone(), card);
1702 }
1703
1704 cx.spawn({
1705 async move |thread: WeakEntity<Thread>, cx| {
1706 let output = tool_result.output.await;
1707
1708 thread
1709 .update(cx, |thread, cx| {
1710 let pending_tool_use = thread.tool_use.insert_tool_output(
1711 tool_use_id.clone(),
1712 tool_name,
1713 output,
1714 cx,
1715 );
1716 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1717 })
1718 .ok();
1719 }
1720 })
1721 }
1722
    /// Marks one tool use as finished. When this was the last outstanding tool
    /// and the run wasn't canceled, the collected results are sent back to the
    /// default model to continue the conversation.
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            let model_registry = LanguageModelRegistry::read_global(cx);
            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
                // Canceled runs still capture telemetry, but don't resume the
                // completion.
                if !canceled {
                    self.send_to_model(model, window, cx);
                }
                self.auto_capture_telemetry(cx);
            }
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }
1746
1747 /// Cancels the last pending completion, if there are any pending.
1748 ///
1749 /// Returns whether a completion was canceled.
1750 pub fn cancel_last_completion(
1751 &mut self,
1752 window: Option<AnyWindowHandle>,
1753 cx: &mut Context<Self>,
1754 ) -> bool {
1755 let mut canceled = self.pending_completions.pop().is_some();
1756
1757 for pending_tool_use in self.tool_use.cancel_pending() {
1758 canceled = true;
1759 self.tool_finished(
1760 pending_tool_use.id.clone(),
1761 Some(pending_tool_use),
1762 true,
1763 window,
1764 cx,
1765 );
1766 }
1767
1768 self.finalize_pending_checkpoint(cx);
1769 canceled
1770 }
1771
    /// Thread-level feedback recorded by `report_feedback`, if any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1775
    /// Feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1779
    /// Records `feedback` for one message and reports it to telemetry along
    /// with a serialized snapshot of the thread and project.
    ///
    /// Returns immediately with `Ok(())` when the same feedback was already
    /// recorded for this message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs app access before going to the
        // background executor.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure shouldn't block the rating event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
1837
    /// Records thread-level feedback. When the thread has an assistant
    /// message, the feedback is attached to the most recent one (delegating to
    /// `report_message_feedback`); otherwise an equivalent thread-level
    /// telemetry event is sent.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach to: capture app-side state here,
            // then finish on the background executor.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
1883
1884 /// Create a snapshot of the current project state including git information and unsaved buffers.
1885 fn project_snapshot(
1886 project: Entity<Project>,
1887 cx: &mut Context<Self>,
1888 ) -> Task<Arc<ProjectSnapshot>> {
1889 let git_store = project.read(cx).git_store().clone();
1890 let worktree_snapshots: Vec<_> = project
1891 .read(cx)
1892 .visible_worktrees(cx)
1893 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1894 .collect();
1895
1896 cx.spawn(async move |_, cx| {
1897 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1898
1899 let mut unsaved_buffers = Vec::new();
1900 cx.update(|app_cx| {
1901 let buffer_store = project.read(app_cx).buffer_store();
1902 for buffer_handle in buffer_store.read(app_cx).buffers() {
1903 let buffer = buffer_handle.read(app_cx);
1904 if buffer.is_dirty() {
1905 if let Some(file) = buffer.file() {
1906 let path = file.path().to_string_lossy().to_string();
1907 unsaved_buffers.push(path);
1908 }
1909 }
1910 }
1911 })
1912 .ok();
1913
1914 Arc::new(ProjectSnapshot {
1915 worktree_snapshots,
1916 unsaved_buffer_paths: unsaved_buffers,
1917 timestamp: Utc::now(),
1918 })
1919 })
1920 }
1921
1922 fn worktree_snapshot(
1923 worktree: Entity<project::Worktree>,
1924 git_store: Entity<GitStore>,
1925 cx: &App,
1926 ) -> Task<WorktreeSnapshot> {
1927 cx.spawn(async move |cx| {
1928 // Get worktree path and snapshot
1929 let worktree_info = cx.update(|app_cx| {
1930 let worktree = worktree.read(app_cx);
1931 let path = worktree.abs_path().to_string_lossy().to_string();
1932 let snapshot = worktree.snapshot();
1933 (path, snapshot)
1934 });
1935
1936 let Ok((worktree_path, _snapshot)) = worktree_info else {
1937 return WorktreeSnapshot {
1938 worktree_path: String::new(),
1939 git_state: None,
1940 };
1941 };
1942
1943 let git_state = git_store
1944 .update(cx, |git_store, cx| {
1945 git_store
1946 .repositories()
1947 .values()
1948 .find(|repo| {
1949 repo.read(cx)
1950 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1951 .is_some()
1952 })
1953 .cloned()
1954 })
1955 .ok()
1956 .flatten()
1957 .map(|repo| {
1958 repo.update(cx, |repo, _| {
1959 let current_branch =
1960 repo.branch.as_ref().map(|branch| branch.name.to_string());
1961 repo.send_job(None, |state, _| async move {
1962 let RepositoryState::Local { backend, .. } = state else {
1963 return GitState {
1964 remote_url: None,
1965 head_sha: None,
1966 current_branch,
1967 diff: None,
1968 };
1969 };
1970
1971 let remote_url = backend.remote_url("origin");
1972 let head_sha = backend.head_sha();
1973 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1974
1975 GitState {
1976 remote_url,
1977 head_sha,
1978 current_branch,
1979 diff,
1980 }
1981 })
1982 })
1983 });
1984
1985 let git_state = match git_state {
1986 Some(git_state) => match git_state.ok() {
1987 Some(git_state) => git_state.await.ok(),
1988 None => None,
1989 },
1990 None => None,
1991 };
1992
1993 WorktreeSnapshot {
1994 worktree_path,
1995 git_state,
1996 }
1997 })
1998 }
1999
    /// Renders the whole thread as a Markdown document: the summary as a
    /// title, then each message with its context, segments, tool uses (input
    /// as JSON blocks), and tool results.
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        // Written via io::Write into a byte buffer, then lossily decoded.
        let mut markdown = Vec::new();

        if let Some(summary) = self.summary() {
            writeln!(markdown, "# {summary}\n")?;
        };

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Assistant",
                    Role::System => "System",
                }
            )?;

            if !message.loaded_context.text.is_empty() {
                writeln!(markdown, "{}", message.loaded_context.text)?;
            }

            // Images can't be embedded; note their count instead.
            if !message.loaded_context.images.is_empty() {
                writeln!(
                    markdown,
                    "\n{} images attached as context.\n",
                    message.loaded_context.images.len()
                )?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    // Redacted thinking is intentionally omitted from output.
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                writeln!(markdown, "{}", tool_result.content)?;
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }
2068
    /// Accepts the agent's edits within `buffer_range`, delegating to the
    /// action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2079
    /// Accepts all of the agent's outstanding edits via the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2084
    /// Reverts the agent's edits within `buffer_ranges`, delegating to the
    /// action log; returns its task.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2095
    /// The action log tracking this thread's buffer reads and edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2099
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2103
2104 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2105 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2106 return;
2107 }
2108
2109 let now = Instant::now();
2110 if let Some(last) = self.last_auto_capture_at {
2111 if now.duration_since(last).as_secs() < 10 {
2112 return;
2113 }
2114 }
2115
2116 self.last_auto_capture_at = Some(now);
2117
2118 let thread_id = self.id().clone();
2119 let github_login = self
2120 .project
2121 .read(cx)
2122 .user_store()
2123 .read(cx)
2124 .current_user()
2125 .map(|user| user.github_login.clone());
2126 let client = self.project.read(cx).client().clone();
2127 let serialize_task = self.serialize(cx);
2128
2129 cx.background_executor()
2130 .spawn(async move {
2131 if let Ok(serialized_thread) = serialize_task.await {
2132 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2133 telemetry::event!(
2134 "Agent Thread Auto-Captured",
2135 thread_id = thread_id.to_string(),
2136 thread_data = thread_data,
2137 auto_capture_reason = "tracked_user",
2138 github_login = github_login
2139 );
2140
2141 client.telemetry().flush_events().await;
2142 }
2143 }
2144 })
2145 .detach();
2146 }
2147
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2151
2152 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2153 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2154 return TotalTokenUsage::default();
2155 };
2156
2157 let max = model.model.max_token_count();
2158
2159 let index = self
2160 .messages
2161 .iter()
2162 .position(|msg| msg.id == message_id)
2163 .unwrap_or(0);
2164
2165 if index == 0 {
2166 return TotalTokenUsage { total: 0, max };
2167 }
2168
2169 let token_usage = &self
2170 .request_token_usage
2171 .get(index - 1)
2172 .cloned()
2173 .unwrap_or_default();
2174
2175 TotalTokenUsage {
2176 total: token_usage.total_tokens() as usize,
2177 max,
2178 }
2179 }
2180
2181 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2182 let model_registry = LanguageModelRegistry::read_global(cx);
2183 let Some(model) = model_registry.default_model() else {
2184 return TotalTokenUsage::default();
2185 };
2186
2187 let max = model.model.max_token_count();
2188
2189 if let Some(exceeded_error) = &self.exceeded_window_error {
2190 if model.model.id() == exceeded_error.model_id {
2191 return TotalTokenUsage {
2192 total: exceeded_error.token_count,
2193 max,
2194 };
2195 }
2196 }
2197
2198 let total = self
2199 .token_usage_at_last_message()
2200 .unwrap_or_default()
2201 .total_tokens() as usize;
2202
2203 TotalTokenUsage { total, max }
2204 }
2205
    /// Usage recorded for the last message, falling back to the most recent
    /// entry when `request_token_usage` is shorter than `messages`.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2212
    /// Records `token_usage` for the last message, first growing
    /// `request_token_usage` to match `messages` (padding new slots with the
    /// previous last value) so indices stay aligned.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2222
2223 pub fn deny_tool_use(
2224 &mut self,
2225 tool_use_id: LanguageModelToolUseId,
2226 tool_name: Arc<str>,
2227 window: Option<AnyWindowHandle>,
2228 cx: &mut Context<Self>,
2229 ) {
2230 let err = Err(anyhow::anyhow!(
2231 "Permission to run tool action denied by user"
2232 ));
2233
2234 self.tool_use
2235 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2236 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2237 }
2238}
2239
/// User-facing errors surfaced to the UI via `ThreadEvent::ShowError`.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account must pay before making further model requests.
    #[error("Payment required")]
    PaymentRequired,
    /// The account hit its configured monthly spend cap.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model-request quota is exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, shown with a header and message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2254
/// Events emitted by a `Thread` for UI and other subscribers.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error to display to the user.
    ShowError(ThreadError),
    /// The provider reported updated request usage/limits.
    UsageUpdated(RequestUsage),
    /// A streaming event was applied to the thread.
    StreamedCompletion,
    /// A text chunk arrived (emitted before it is applied).
    ReceivedTextChunk,
    /// Assistant text appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use whose input is still streaming in.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, with its stop reason or error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// A short thread title was generated.
    SummaryGenerated,
    SummaryChanged,
    /// The completion stopped for tool use; these tools were started.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool use produced output (or was canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it may run.
    ToolConfirmationNeeded,
}
2291
// `Thread` broadcasts `ThreadEvent`s to subscribers through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2293
/// An in-flight model completion; dropping it aborts the underlying task.
struct PendingCompletion {
    id: usize,
    // Held only to keep the spawned stream task alive.
    _task: Task<()>,
}
2298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// A user message with an attached file should carry the file's contents
    /// in `loaded_context`, and the completion request should prepend that
    /// `<context>` block to the message text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        // Grab the single context entry we just added and load its contents.
        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", loaded_context, None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // The exact markup the context loader is expected to produce; the raw
        // string's indentation is significant and must match byte-for-byte.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // NOTE(review): len is 2 for one inserted message, so messages[0] is an
        // extra leading message (presumably system/project context) — confirm
        // in `to_completion_request`, which is outside this chunk.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    /// Each successive user message should only pick up context that was not
    /// already attached to an earlier message in the same thread
    /// (`ContextStore::new_context_for_thread` deduplicates).
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // The request should contain all 3 messages
        // (plus one leading message at index 0, as in the test above).
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    /// A message inserted with an empty `ContextLoadResult` should have empty
    /// context text, and its request content should be the bare message.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    /// After a buffer that was attached as context is edited, the completion
    /// request should end with a "stale buffer" notification message listing
    /// the changed file.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer so it becomes stale relative to the attached context
        buffer.update(cx, |buffer, cx| {
            // Insert at byte offset 1..1 (just after the first character); the
            // exact position is irrelevant — any edit marks the buffer stale.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    /// Registers the global settings required by the project, workspace,
    /// theme, editor, and assistant machinery exercised in these tests.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Builds the full entity graph the tests need: a test `Workspace`, a
    /// `ThreadStore`, a fresh `Thread`, and an empty `ContextStore`.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    /// Opens `path` as a buffer in `project` and registers it with the
    /// context store; returns the buffer so tests can edit it afterwards.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
}