1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
42
43#[derive(
44 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
45)]
46pub struct ThreadId(Arc<str>);
47
48impl ThreadId {
49 pub fn new() -> Self {
50 Self(Uuid::new_v4().to_string().into())
51 }
52}
53
54impl std::fmt::Display for ThreadId {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
60impl From<&str> for ThreadId {
61 fn from(value: &str) -> Self {
62 Self(value.into())
63 }
64}
65
66/// The ID of the user prompt that initiated a request.
67///
68/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
70pub struct PromptId(Arc<str>);
71
72impl PromptId {
73 pub fn new() -> Self {
74 Self(Uuid::new_v4().to_string().into())
75 }
76}
77
78impl std::fmt::Display for PromptId {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "{}", self.0)
81 }
82}
83
84#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
85pub struct MessageId(pub(crate) usize);
86
87impl MessageId {
88 fn post_inc(&mut self) -> Self {
89 Self(post_inc(&mut self.0))
90 }
91}
92
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text/thinking segments making up the message body.
    pub segments: Vec<MessageSegment>,
    /// Pre-formatted context (files, symbols, etc.) attached to this message.
    pub context: String,
    /// Images attached alongside this message.
    pub images: Vec<LanguageModelImage>,
}
102
103impl Message {
104 /// Returns whether the message contains any meaningful text that should be displayed
105 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
106 pub fn should_display_content(&self) -> bool {
107 self.segments.iter().all(|segment| segment.should_display())
108 }
109
110 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
111 if let Some(MessageSegment::Thinking {
112 text: segment,
113 signature: current_signature,
114 }) = self.segments.last_mut()
115 {
116 if let Some(signature) = signature {
117 *current_signature = Some(signature);
118 }
119 segment.push_str(text);
120 } else {
121 self.segments.push(MessageSegment::Thinking {
122 text: text.to_string(),
123 signature,
124 });
125 }
126 }
127
128 pub fn push_text(&mut self, text: &str) {
129 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
130 segment.push_str(text);
131 } else {
132 self.segments.push(MessageSegment::Text(text.to_string()));
133 }
134 }
135
136 pub fn to_string(&self) -> String {
137 let mut result = String::new();
138
139 if !self.context.is_empty() {
140 result.push_str(&self.context);
141 }
142
143 for segment in &self.segments {
144 match segment {
145 MessageSegment::Text(text) => result.push_str(text),
146 MessageSegment::Thinking { text, .. } => {
147 result.push_str("<think>\n");
148 result.push_str(text);
149 result.push_str("\n</think>");
150 }
151 MessageSegment::RedactedThinking(_) => {}
152 }
153 }
154
155 result
156 }
157}
158
/// One piece of a message: plain text, visible model "thinking", or
/// provider-redacted thinking data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment carries content worth displaying.
    ///
    /// Empty text/thinking segments and redacted thinking are not displayable.
    /// Note: this previously returned `text.is_empty()`, which inverted the
    /// contract described on [`Message::should_display_content`].
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
178
/// A snapshot of the project's worktrees and unsaved-buffer state, captured
/// for storage/telemetry alongside a thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved edits at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree portion of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when no git state was available for this worktree.
    pub git_state: Option<GitState>,
}

/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff at capture time, if one was computed.
    pub diff: Option<String>,
}
199
/// Associates a message with a git checkpoint so the project can later be
/// restored to the state it was in when that message was sent.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread (thumbs up / thumbs down).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

/// State of the most recent checkpoint restoration.
pub enum LastRestoreCheckpoint {
    /// Restoration is still in progress.
    Pending {
        message_id: MessageId,
    },
    /// Restoration failed with the given error message.
    Error {
        message_id: MessageId,
        error: String,
    },
}
221
222impl LastRestoreCheckpoint {
223 pub fn message_id(&self) -> MessageId {
224 match self {
225 LastRestoreCheckpoint::Pending { message_id } => *message_id,
226 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
227 }
228 }
229}
230
/// Progress of generating a detailed, multi-sentence summary of a thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    /// A summary covering messages up to `message_id` is being generated.
    Generating {
        message_id: MessageId,
    },
    /// A summary covering messages up to `message_id` was generated.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
243
/// Cumulative token usage for a thread, relative to a model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens allowed by the model's context window.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, nearing the limit, or exceeded.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via an env
        // var for testing. Fall back to the default instead of panicking when
        // the variable is missing or not a valid float (the previous
        // `.unwrap()` would crash on malformed input).
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When `max` is 0, `total >= max` always holds, so we report
        // `Exceeded` without ever dividing by zero below.
        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close a thread's token usage is to the model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
284
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Short summary shown in the UI; `None` until one has been generated.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    /// All context ever attached to this thread, keyed by id.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Which context ids were introduced by which message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured for the in-flight user message; kept or discarded
    /// once generation finishes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Snapshot of the project taken when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Optional observer invoked with each completion request and the events
    /// streamed back for it.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Remaining agentic turns; decremented by each send to the model.
    remaining_turns: u32,
}
320
/// Recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
328
329impl Thread {
    /// Creates a new, empty thread backed by the given project and tool set.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project's state up front so the thread can later be
            // serialized with a snapshot of how the project looked when the
            // conversation started.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
375
    /// Reconstructs a thread from its serialized form.
    ///
    /// Volatile state (pending completions, checkpoints, feedback, attached
    /// context bookkeeping) is not persisted and is reset to defaults.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    context: message.context,
                    // Images are not restored from the serialized form.
                    images: Vec::new(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
450
    /// Stores a callback that receives a completion request together with the
    /// events streamed back in response to it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
458
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Summary used before a real one has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The generated summary, or [`Self::DEFAULT_SUMMARY`] when none exists yet.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
492
493 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
494 let Some(current_summary) = &self.summary else {
495 // Don't allow setting summary until generated
496 return;
497 };
498
499 let mut new_summary = new_summary.into();
500
501 if new_summary.is_empty() {
502 new_summary = Self::DEFAULT_SUMMARY;
503 }
504
505 if current_summary != &new_summary {
506 self.summary = Some(new_summary);
507 cx.emit(ThreadEvent::SummaryChanged);
508 }
509 }
510
511 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
512 self.latest_detailed_summary()
513 .unwrap_or_else(|| self.text().into())
514 }
515
516 fn latest_detailed_summary(&self) -> Option<SharedString> {
517 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
518 Some(text.clone())
519 } else {
520 None
521 }
522 }
523
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by its id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that require user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
561
    /// Restores the project to the given checkpoint's git state and, on
    /// success, truncates the thread back to the checkpoint's message.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Surface the in-progress state immediately so observers can react.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around so it can be displayed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way the pending checkpoint is no longer applicable.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
597
    /// Decides what to do with the checkpoint taken when the last user message
    /// was sent: keep it if the project changed during generation, otherwise
    /// delete it.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        // Wait until generation has fully finished before finalizing.
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    // Nothing changed during generation; the checkpoint is
                    // useless, so discard it instead of keeping it.
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The "final" checkpoint existed only for the comparison.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            // If taking a comparison checkpoint failed, conservatively keep
            // the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
648
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint restoration, if one happened.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
659
660 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
661 let Some(message_ix) = self
662 .messages
663 .iter()
664 .rposition(|message| message.id == message_id)
665 else {
666 return;
667 };
668 for deleted_message in self.messages.drain(message_ix..) {
669 self.context_by_message.remove(&deleted_message.id);
670 self.checkpoints_by_message.remove(&deleted_message.id);
671 }
672 cx.notify();
673 }
674
    /// Iterates over the context entries attached to the given message.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
        self.context_by_message
            .get(&id)
            .into_iter()
            .flat_map(|context| {
                context
                    .iter()
                    .filter_map(|context_id| self.context.get(&context_id))
            })
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
695
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output of a completed tool use, if a result exists.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The rendered card for a completed tool use, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }

    /// Filter out contexts that have already been included in previous messages
    pub fn filter_new_context<'a>(
        &self,
        context: impl Iterator<Item = &'a AssistantContext>,
    ) -> impl Iterator<Item = &'a AssistantContext> {
        context.filter(|ctx| self.is_context_new(ctx))
    }

    /// Whether this context has not yet been attached to the thread.
    fn is_context_new(&self, context: &AssistantContext) -> bool {
        !self.context.contains_key(&context.id())
    }
731
    /// Appends a user message, attaching any context not already present on an
    /// earlier message, and optionally recording a pending git checkpoint.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        // Only keep context that hasn't already been attached to the thread.
        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                // Attach only images whose loading has already completed;
                // `now_or_never` does not wait for still-loading images.
                message.images = new_context
                    .iter()
                    .filter_map(|context| {
                        if let AssistantContext::Image(image_context) = context {
                            image_context.image_task.clone().now_or_never().flatten()
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>();
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.buffer_added_as_context(
                                symbol_ctx.context_symbol.buffer.clone(),
                                cx,
                            );
                        }
                        AssistantContext::Selection(selection_context) => {
                            log.buffer_added_as_context(
                                selection_context.context_buffer.buffer.clone(),
                                cx,
                            );
                        }
                        // These context kinds carry no buffers to track.
                        AssistantContext::FetchedUrl(_)
                        | AssistantContext::Thread(_)
                        | AssistantContext::Rules(_)
                        | AssistantContext::Image(_) => {}
                    }
                }
            });
        }

        // Record the new context and associate it with this message.
        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        // A checkpoint lets the user later restore the project to the state it
        // was in when this message was sent.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
823
    /// Appends a message with the given role and segments, returning its id.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            // Context and images are attached separately (see
            // `insert_user_message`).
            context: String::new(),
            images: Vec::new(),
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
842
843 pub fn edit_message(
844 &mut self,
845 id: MessageId,
846 new_role: Role,
847 new_segments: Vec<MessageSegment>,
848 cx: &mut Context<Self>,
849 ) -> bool {
850 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
851 return false;
852 };
853 message.role = new_role;
854 message.segments = new_segments;
855 self.touch_updated_at();
856 cx.emit(ThreadEvent::MessageEdited(id));
857 true
858 }
859
860 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
861 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
862 return false;
863 };
864 self.messages.remove(index);
865 self.context_by_message.remove(&id);
866 self.touch_updated_at();
867 cx.emit(ThreadEvent::MessageDeleted(id));
868 true
869 }
870
871 /// Returns the representation of this [`Thread`] in a textual form.
872 ///
873 /// This is the representation we use when attaching a thread as context to another thread.
874 pub fn text(&self) -> String {
875 let mut text = String::new();
876
877 for message in &self.messages {
878 text.push_str(match message.role {
879 language_model::Role::User => "User:",
880 language_model::Role::Assistant => "Assistant:",
881 language_model::Role::System => "System:",
882 });
883 text.push('\n');
884
885 for segment in &message.segments {
886 match segment {
887 MessageSegment::Text(content) => text.push_str(content),
888 MessageSegment::Thinking { text: content, .. } => {
889 text.push_str(&format!("<think>{}</think>", content))
890 }
891 MessageSegment::RedactedThinking(_) => {}
892 }
893 }
894 text.push('\n');
895 }
896
897 text
898 }
899
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the startup project snapshot to resolve first.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
963
    /// Number of agentic turns the thread may still take without new user input.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
971
972 pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
973 if self.remaining_turns == 0 {
974 return;
975 }
976
977 self.remaining_turns -= 1;
978
979 let mut request = self.to_completion_request(cx);
980 if model.supports_tools() {
981 request.tools = {
982 let mut tools = Vec::new();
983 tools.extend(
984 self.tools()
985 .read(cx)
986 .enabled_tools(cx)
987 .into_iter()
988 .filter_map(|tool| {
989 // Skip tools that cannot be supported
990 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
991 Some(LanguageModelRequestTool {
992 name: tool.name(),
993 description: tool.description(),
994 input_schema,
995 })
996 }),
997 );
998
999 tools
1000 };
1001 }
1002
1003 self.stream_completion(request, model, cx);
1004 }
1005
1006 pub fn used_tools_since_last_user_message(&self) -> bool {
1007 for message in self.messages.iter().rev() {
1008 if self.tool_use.message_has_tool_results(message.id) {
1009 return true;
1010 } else if message.role == Role::User {
1011 return false;
1012 }
1013 }
1014
1015 false
1016 }
1017
    /// Builds the full [`LanguageModelRequest`] for this thread: the system
    /// prompt, then each message with its context, images, segments, and tool
    /// use/result traffic, with cache markers applied.
    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        // The system prompt is stable, so mark it cacheable.
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Tool results must precede the rest of the message content.
            self.tool_use
                .attach_tool_results(message.id, &mut request_message);

            if !message.context.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.context.to_string()));
            }

            if !message.images.is_empty() {
                // Some providers only support image parts after an initial text part
                if request_message.content.is_empty() {
                    request_message
                        .content
                        .push(MessageContent::Text("Images attached by user:".to_string()));
                }

                for image in &message.images {
                    request_message
                        .content
                        .push(MessageContent::Image(image.clone()))
                }
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1129
1130 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1131 let mut request = LanguageModelRequest {
1132 thread_id: None,
1133 prompt_id: None,
1134 messages: vec![],
1135 tools: Vec::new(),
1136 stop: Vec::new(),
1137 temperature: None,
1138 };
1139
1140 for message in &self.messages {
1141 let mut request_message = LanguageModelRequestMessage {
1142 role: message.role,
1143 content: Vec::new(),
1144 cache: false,
1145 };
1146
1147 // Skip tool results during summarization.
1148 if self.tool_use.message_has_tool_results(message.id) {
1149 continue;
1150 }
1151
1152 for segment in &message.segments {
1153 match segment {
1154 MessageSegment::Text(text) => request_message
1155 .content
1156 .push(MessageContent::Text(text.clone())),
1157 MessageSegment::Thinking { .. } => {}
1158 MessageSegment::RedactedThinking(_) => {}
1159 }
1160 }
1161
1162 if request_message.content.is_empty() {
1163 continue;
1164 }
1165
1166 request.messages.push(request_message);
1167 }
1168
1169 request.messages.push(LanguageModelRequestMessage {
1170 role: Role::User,
1171 content: vec![MessageContent::Text(added_user_message)],
1172 cache: false,
1173 });
1174
1175 request
1176 }
1177
1178 fn attached_tracked_files_state(
1179 &self,
1180 messages: &mut Vec<LanguageModelRequestMessage>,
1181 cx: &App,
1182 ) {
1183 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1184
1185 let mut stale_message = String::new();
1186
1187 let action_log = self.action_log.read(cx);
1188
1189 for stale_file in action_log.stale_buffers(cx) {
1190 let Some(file) = stale_file.read(cx).file() else {
1191 continue;
1192 };
1193
1194 if stale_message.is_empty() {
1195 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1196 }
1197
1198 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1199 }
1200
1201 let mut content = Vec::with_capacity(2);
1202
1203 if !stale_message.is_empty() {
1204 content.push(stale_message.into());
1205 }
1206
1207 if !content.is_empty() {
1208 let context_message = LanguageModelRequestMessage {
1209 role: Role::User,
1210 content,
1211 cache: false,
1212 };
1213
1214 messages.push(context_message);
1215 }
1216 }
1217
1218 pub fn stream_completion(
1219 &mut self,
1220 request: LanguageModelRequest,
1221 model: Arc<dyn LanguageModel>,
1222 cx: &mut Context<Self>,
1223 ) {
1224 let pending_completion_id = post_inc(&mut self.completion_count);
1225 let mut request_callback_parameters = if self.request_callback.is_some() {
1226 Some((request.clone(), Vec::new()))
1227 } else {
1228 None
1229 };
1230 let prompt_id = self.last_prompt_id.clone();
1231 let tool_use_metadata = ToolUseMetadata {
1232 model: model.clone(),
1233 thread_id: self.id.clone(),
1234 prompt_id: prompt_id.clone(),
1235 };
1236
1237 let task = cx.spawn(async move |thread, cx| {
1238 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1239 let initial_token_usage =
1240 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1241 let stream_completion = async {
1242 let (mut events, usage) = stream_completion_future.await?;
1243
1244 let mut stop_reason = StopReason::EndTurn;
1245 let mut current_token_usage = TokenUsage::default();
1246
1247 if let Some(usage) = usage {
1248 thread
1249 .update(cx, |_thread, cx| {
1250 cx.emit(ThreadEvent::UsageUpdated(usage));
1251 })
1252 .ok();
1253 }
1254
1255 while let Some(event) = events.next().await {
1256 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1257 response_events
1258 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1259 }
1260
1261 let event = event?;
1262
1263 thread.update(cx, |thread, cx| {
1264 match event {
1265 LanguageModelCompletionEvent::StartMessage { .. } => {
1266 thread.insert_message(
1267 Role::Assistant,
1268 vec![MessageSegment::Text(String::new())],
1269 cx,
1270 );
1271 }
1272 LanguageModelCompletionEvent::Stop(reason) => {
1273 stop_reason = reason;
1274 }
1275 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1276 thread.update_token_usage_at_last_message(token_usage);
1277 thread.cumulative_token_usage = thread.cumulative_token_usage
1278 + token_usage
1279 - current_token_usage;
1280 current_token_usage = token_usage;
1281 }
1282 LanguageModelCompletionEvent::Text(chunk) => {
1283 cx.emit(ThreadEvent::ReceivedTextChunk);
1284 if let Some(last_message) = thread.messages.last_mut() {
1285 if last_message.role == Role::Assistant {
1286 last_message.push_text(&chunk);
1287 cx.emit(ThreadEvent::StreamedAssistantText(
1288 last_message.id,
1289 chunk,
1290 ));
1291 } else {
1292 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1293 // of a new Assistant response.
1294 //
1295 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1296 // will result in duplicating the text of the chunk in the rendered Markdown.
1297 thread.insert_message(
1298 Role::Assistant,
1299 vec![MessageSegment::Text(chunk.to_string())],
1300 cx,
1301 );
1302 };
1303 }
1304 }
1305 LanguageModelCompletionEvent::Thinking {
1306 text: chunk,
1307 signature,
1308 } => {
1309 if let Some(last_message) = thread.messages.last_mut() {
1310 if last_message.role == Role::Assistant {
1311 last_message.push_thinking(&chunk, signature);
1312 cx.emit(ThreadEvent::StreamedAssistantThinking(
1313 last_message.id,
1314 chunk,
1315 ));
1316 } else {
1317 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1318 // of a new Assistant response.
1319 //
1320 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1321 // will result in duplicating the text of the chunk in the rendered Markdown.
1322 thread.insert_message(
1323 Role::Assistant,
1324 vec![MessageSegment::Thinking {
1325 text: chunk.to_string(),
1326 signature,
1327 }],
1328 cx,
1329 );
1330 };
1331 }
1332 }
1333 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1334 let last_assistant_message_id = thread
1335 .messages
1336 .iter_mut()
1337 .rfind(|message| message.role == Role::Assistant)
1338 .map(|message| message.id)
1339 .unwrap_or_else(|| {
1340 thread.insert_message(Role::Assistant, vec![], cx)
1341 });
1342
1343 let tool_use_id = tool_use.id.clone();
1344 let streamed_input = if tool_use.is_input_complete {
1345 None
1346 } else {
1347 Some((&tool_use.input).clone())
1348 };
1349
1350 let ui_text = thread.tool_use.request_tool_use(
1351 last_assistant_message_id,
1352 tool_use,
1353 tool_use_metadata.clone(),
1354 cx,
1355 );
1356
1357 if let Some(input) = streamed_input {
1358 cx.emit(ThreadEvent::StreamedToolUse {
1359 tool_use_id,
1360 ui_text,
1361 input,
1362 });
1363 }
1364 }
1365 }
1366
1367 thread.touch_updated_at();
1368 cx.emit(ThreadEvent::StreamedCompletion);
1369 cx.notify();
1370
1371 thread.auto_capture_telemetry(cx);
1372 })?;
1373
1374 smol::future::yield_now().await;
1375 }
1376
1377 thread.update(cx, |thread, cx| {
1378 thread
1379 .pending_completions
1380 .retain(|completion| completion.id != pending_completion_id);
1381
1382 // If there is a response without tool use, summarize the message. Otherwise,
1383 // allow two tool uses before summarizing.
1384 if thread.summary.is_none()
1385 && thread.messages.len() >= 2
1386 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1387 {
1388 thread.summarize(cx);
1389 }
1390 })?;
1391
1392 anyhow::Ok(stop_reason)
1393 };
1394
1395 let result = stream_completion.await;
1396
1397 thread
1398 .update(cx, |thread, cx| {
1399 thread.finalize_pending_checkpoint(cx);
1400 match result.as_ref() {
1401 Ok(stop_reason) => match stop_reason {
1402 StopReason::ToolUse => {
1403 let tool_uses = thread.use_pending_tools(cx);
1404 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1405 }
1406 StopReason::EndTurn => {}
1407 StopReason::MaxTokens => {}
1408 },
1409 Err(error) => {
1410 if error.is::<PaymentRequiredError>() {
1411 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1412 } else if error.is::<MaxMonthlySpendReachedError>() {
1413 cx.emit(ThreadEvent::ShowError(
1414 ThreadError::MaxMonthlySpendReached,
1415 ));
1416 } else if let Some(error) =
1417 error.downcast_ref::<ModelRequestLimitReachedError>()
1418 {
1419 cx.emit(ThreadEvent::ShowError(
1420 ThreadError::ModelRequestLimitReached { plan: error.plan },
1421 ));
1422 } else if let Some(known_error) =
1423 error.downcast_ref::<LanguageModelKnownError>()
1424 {
1425 match known_error {
1426 LanguageModelKnownError::ContextWindowLimitExceeded {
1427 tokens,
1428 } => {
1429 thread.exceeded_window_error = Some(ExceededWindowError {
1430 model_id: model.id(),
1431 token_count: *tokens,
1432 });
1433 cx.notify();
1434 }
1435 }
1436 } else {
1437 let error_message = error
1438 .chain()
1439 .map(|err| err.to_string())
1440 .collect::<Vec<_>>()
1441 .join("\n");
1442 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1443 header: "Error interacting with language model".into(),
1444 message: SharedString::from(error_message.clone()),
1445 }));
1446 }
1447
1448 thread.cancel_last_completion(cx);
1449 }
1450 }
1451 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1452
1453 if let Some((request_callback, (request, response_events))) = thread
1454 .request_callback
1455 .as_mut()
1456 .zip(request_callback_parameters.as_ref())
1457 {
1458 request_callback(request, response_events);
1459 }
1460
1461 thread.auto_capture_telemetry(cx);
1462
1463 if let Ok(initial_usage) = initial_token_usage {
1464 let usage = thread.cumulative_token_usage - initial_usage;
1465
1466 telemetry::event!(
1467 "Assistant Thread Completion",
1468 thread_id = thread.id().to_string(),
1469 prompt_id = prompt_id,
1470 model = model.telemetry_id(),
1471 model_provider = model.provider_id().to_string(),
1472 input_tokens = usage.input_tokens,
1473 output_tokens = usage.output_tokens,
1474 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1475 cache_read_input_tokens = usage.cache_read_input_tokens,
1476 );
1477 }
1478 })
1479 .ok();
1480 });
1481
1482 self.pending_completions.push(PendingCompletion {
1483 id: pending_completion_id,
1484 _task: task,
1485 });
1486 }
1487
1488 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1489 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1490 return;
1491 };
1492
1493 if !model.provider.is_authenticated(cx) {
1494 return;
1495 }
1496
1497 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1498 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1499 If the conversation is about a specific subject, include it in the title. \
1500 Be descriptive. DO NOT speak in the first person.";
1501
1502 let request = self.to_summarize_request(added_user_message.into());
1503
1504 self.pending_summary = cx.spawn(async move |this, cx| {
1505 async move {
1506 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1507 let (mut messages, usage) = stream.await?;
1508
1509 if let Some(usage) = usage {
1510 this.update(cx, |_thread, cx| {
1511 cx.emit(ThreadEvent::UsageUpdated(usage));
1512 })
1513 .ok();
1514 }
1515
1516 let mut new_summary = String::new();
1517 while let Some(message) = messages.stream.next().await {
1518 let text = message?;
1519 let mut lines = text.lines();
1520 new_summary.extend(lines.next());
1521
1522 // Stop if the LLM generated multiple lines.
1523 if lines.next().is_some() {
1524 break;
1525 }
1526 }
1527
1528 this.update(cx, |this, cx| {
1529 if !new_summary.is_empty() {
1530 this.summary = Some(new_summary.into());
1531 }
1532
1533 cx.emit(ThreadEvent::SummaryGenerated);
1534 })?;
1535
1536 anyhow::Ok(())
1537 }
1538 .log_err()
1539 .await
1540 });
1541 }
1542
    /// Kicks off generation of a detailed Markdown summary of the thread.
    ///
    /// Returns `None` when a summary for the latest message is already
    /// generated or in flight, when there are no messages, or when no
    /// authenticated summary model is available. Otherwise returns the task
    /// driving generation; `detailed_summary_state` transitions
    /// `Generating` -> `Generated` (or back to `NotGenerated` on failure).
    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request itself fails, reset the state so a later call can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            // Accumulate the streamed summary; chunk errors are logged and skipped.
            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        // Mark generation as in-flight before handing the task to the caller.
        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }
1609
    /// Whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1616
1617 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1618 self.auto_capture_telemetry(cx);
1619 let request = self.to_completion_request(cx);
1620 let messages = Arc::new(request.messages);
1621 let pending_tool_uses = self
1622 .tool_use
1623 .pending_tool_uses()
1624 .into_iter()
1625 .filter(|tool_use| tool_use.status.is_idle())
1626 .cloned()
1627 .collect::<Vec<_>>();
1628
1629 for tool_use in pending_tool_uses.iter() {
1630 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1631 if tool.needs_confirmation(&tool_use.input, cx)
1632 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1633 {
1634 self.tool_use.confirm_tool_use(
1635 tool_use.id.clone(),
1636 tool_use.ui_text.clone(),
1637 tool_use.input.clone(),
1638 messages.clone(),
1639 tool,
1640 );
1641 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1642 } else {
1643 self.run_tool(
1644 tool_use.id.clone(),
1645 tool_use.ui_text.clone(),
1646 tool_use.input.clone(),
1647 &messages,
1648 tool,
1649 cx,
1650 );
1651 }
1652 }
1653 }
1654
1655 pending_tool_uses
1656 }
1657
    /// Runs `tool` with `input` for the given tool use, recording the spawned
    /// task in the tool-use state so its status is tracked until completion.
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        messages: &[LanguageModelRequestMessage],
        tool: Arc<dyn Tool>,
        cx: &mut Context<Thread>,
    ) {
        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
1671
    /// Spawns the asynchronous execution of a single tool use and returns the
    /// task driving it. When the tool's output resolves, it is inserted into
    /// the tool-use state and `tool_finished` is invoked.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Disabled tools short-circuit to an error result instead of running.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            cx,
                        );
                        // `canceled: false` — the tool ran to completion (even if it errored).
                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
                    })
                    .ok();
            }
        })
    }
1718
1719 fn tool_finished(
1720 &mut self,
1721 tool_use_id: LanguageModelToolUseId,
1722 pending_tool_use: Option<PendingToolUse>,
1723 canceled: bool,
1724 cx: &mut Context<Self>,
1725 ) {
1726 if self.all_tools_finished() {
1727 let model_registry = LanguageModelRegistry::read_global(cx);
1728 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1729 self.attach_tool_results(cx);
1730 if !canceled {
1731 self.send_to_model(model, cx);
1732 }
1733 }
1734 }
1735
1736 cx.emit(ThreadEvent::ToolFinished {
1737 tool_use_id,
1738 pending_tool_use,
1739 });
1740 }
1741
    /// Insert an empty message to be populated with tool results upon send.
    ///
    /// The message is inserted with `Role::User` because tool results are sent
    /// back to the model as part of the user's turn.
    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Tool results are assumed to be waiting on the next message id, so they will populate
        // this empty message before sending to model. Would prefer this to be more straightforward.
        self.insert_message(Role::User, vec![], cx);
        self.auto_capture_telemetry(cx);
    }
1749
1750 /// Cancels the last pending completion, if there are any pending.
1751 ///
1752 /// Returns whether a completion was canceled.
1753 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1754 let canceled = if self.pending_completions.pop().is_some() {
1755 true
1756 } else {
1757 let mut canceled = false;
1758 for pending_tool_use in self.tool_use.cancel_pending() {
1759 canceled = true;
1760 self.tool_finished(
1761 pending_tool_use.id.clone(),
1762 Some(pending_tool_use),
1763 true,
1764 cx,
1765 );
1766 }
1767 canceled
1768 };
1769 self.finalize_pending_checkpoint(cx);
1770 canceled
1771 }
1772
    /// Thread-level feedback, if the user has rated this thread.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1776
    /// Feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1780
1781 pub fn report_message_feedback(
1782 &mut self,
1783 message_id: MessageId,
1784 feedback: ThreadFeedback,
1785 cx: &mut Context<Self>,
1786 ) -> Task<Result<()>> {
1787 if self.message_feedback.get(&message_id) == Some(&feedback) {
1788 return Task::ready(Ok(()));
1789 }
1790
1791 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1792 let serialized_thread = self.serialize(cx);
1793 let thread_id = self.id().clone();
1794 let client = self.project.read(cx).client();
1795
1796 let enabled_tool_names: Vec<String> = self
1797 .tools()
1798 .read(cx)
1799 .enabled_tools(cx)
1800 .iter()
1801 .map(|tool| tool.name().to_string())
1802 .collect();
1803
1804 self.message_feedback.insert(message_id, feedback);
1805
1806 cx.notify();
1807
1808 let message_content = self
1809 .message(message_id)
1810 .map(|msg| msg.to_string())
1811 .unwrap_or_default();
1812
1813 cx.background_spawn(async move {
1814 let final_project_snapshot = final_project_snapshot.await;
1815 let serialized_thread = serialized_thread.await?;
1816 let thread_data =
1817 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1818
1819 let rating = match feedback {
1820 ThreadFeedback::Positive => "positive",
1821 ThreadFeedback::Negative => "negative",
1822 };
1823 telemetry::event!(
1824 "Assistant Thread Rated",
1825 rating,
1826 thread_id,
1827 enabled_tool_names,
1828 message_id = message_id.0,
1829 message_content,
1830 thread_data,
1831 final_project_snapshot
1832 );
1833 client.telemetry().flush_events().await;
1834
1835 Ok(())
1836 })
1837 }
1838
1839 pub fn report_feedback(
1840 &mut self,
1841 feedback: ThreadFeedback,
1842 cx: &mut Context<Self>,
1843 ) -> Task<Result<()>> {
1844 let last_assistant_message_id = self
1845 .messages
1846 .iter()
1847 .rev()
1848 .find(|msg| msg.role == Role::Assistant)
1849 .map(|msg| msg.id);
1850
1851 if let Some(message_id) = last_assistant_message_id {
1852 self.report_message_feedback(message_id, feedback, cx)
1853 } else {
1854 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1855 let serialized_thread = self.serialize(cx);
1856 let thread_id = self.id().clone();
1857 let client = self.project.read(cx).client();
1858 self.feedback = Some(feedback);
1859 cx.notify();
1860
1861 cx.background_spawn(async move {
1862 let final_project_snapshot = final_project_snapshot.await;
1863 let serialized_thread = serialized_thread.await?;
1864 let thread_data = serde_json::to_value(serialized_thread)
1865 .unwrap_or_else(|_| serde_json::Value::Null);
1866
1867 let rating = match feedback {
1868 ThreadFeedback::Positive => "positive",
1869 ThreadFeedback::Negative => "negative",
1870 };
1871 telemetry::event!(
1872 "Assistant Thread Rated",
1873 rating,
1874 thread_id,
1875 thread_data,
1876 final_project_snapshot
1877 );
1878 client.telemetry().flush_events().await;
1879
1880 Ok(())
1881 })
1882 }
1883 }
1884
1885 /// Create a snapshot of the current project state including git information and unsaved buffers.
1886 fn project_snapshot(
1887 project: Entity<Project>,
1888 cx: &mut Context<Self>,
1889 ) -> Task<Arc<ProjectSnapshot>> {
1890 let git_store = project.read(cx).git_store().clone();
1891 let worktree_snapshots: Vec<_> = project
1892 .read(cx)
1893 .visible_worktrees(cx)
1894 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1895 .collect();
1896
1897 cx.spawn(async move |_, cx| {
1898 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1899
1900 let mut unsaved_buffers = Vec::new();
1901 cx.update(|app_cx| {
1902 let buffer_store = project.read(app_cx).buffer_store();
1903 for buffer_handle in buffer_store.read(app_cx).buffers() {
1904 let buffer = buffer_handle.read(app_cx);
1905 if buffer.is_dirty() {
1906 if let Some(file) = buffer.file() {
1907 let path = file.path().to_string_lossy().to_string();
1908 unsaved_buffers.push(path);
1909 }
1910 }
1911 }
1912 })
1913 .ok();
1914
1915 Arc::new(ProjectSnapshot {
1916 worktree_snapshots,
1917 unsaved_buffer_paths: unsaved_buffers,
1918 timestamp: Utc::now(),
1919 })
1920 })
1921 }
1922
    /// Captures a snapshot of a single worktree: its absolute path plus git
    /// metadata (remote URL, HEAD SHA, current branch, HEAD-to-worktree diff)
    /// when the worktree belongs to a local repository.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository containing this worktree, then run a git job
            // against it to gather remote/branch/diff information.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name.to_string());
                        repo.send_job(None, |state, _| async move {
                            // Non-local repositories can only report the branch.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha();
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option/Result layers; any failure yields None.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2000
2001 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2002 let mut markdown = Vec::new();
2003
2004 if let Some(summary) = self.summary() {
2005 writeln!(markdown, "# {summary}\n")?;
2006 };
2007
2008 for message in self.messages() {
2009 writeln!(
2010 markdown,
2011 "## {role}\n",
2012 role = match message.role {
2013 Role::User => "User",
2014 Role::Assistant => "Assistant",
2015 Role::System => "System",
2016 }
2017 )?;
2018
2019 if !message.context.is_empty() {
2020 writeln!(markdown, "{}", message.context)?;
2021 }
2022
2023 for segment in &message.segments {
2024 match segment {
2025 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2026 MessageSegment::Thinking { text, .. } => {
2027 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2028 }
2029 MessageSegment::RedactedThinking(_) => {}
2030 }
2031 }
2032
2033 for tool_use in self.tool_uses_for_message(message.id, cx) {
2034 writeln!(
2035 markdown,
2036 "**Use Tool: {} ({})**",
2037 tool_use.name, tool_use.id
2038 )?;
2039 writeln!(markdown, "```json")?;
2040 writeln!(
2041 markdown,
2042 "{}",
2043 serde_json::to_string_pretty(&tool_use.input)?
2044 )?;
2045 writeln!(markdown, "```")?;
2046 }
2047
2048 for tool_result in self.tool_results_for_message(message.id) {
2049 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
2050 if tool_result.is_error {
2051 write!(markdown, " (Error)")?;
2052 }
2053
2054 writeln!(markdown, "**\n")?;
2055 writeln!(markdown, "{}", tool_result.content)?;
2056 }
2057 }
2058
2059 Ok(String::from_utf8_lossy(&markdown).to_string())
2060 }
2061
    /// Accepts (keeps) the agent's edits within `buffer_range` of `buffer`,
    /// delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2072
    /// Accepts (keeps) all outstanding agent edits across every buffer.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2077
    /// Reverts the agent's edits within the given ranges of `buffer`,
    /// delegating to the action log; the returned task resolves when the
    /// rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2088
    /// The action log tracking the agent's edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2092
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2096
2097 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2098 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2099 return;
2100 }
2101
2102 let now = Instant::now();
2103 if let Some(last) = self.last_auto_capture_at {
2104 if now.duration_since(last).as_secs() < 10 {
2105 return;
2106 }
2107 }
2108
2109 self.last_auto_capture_at = Some(now);
2110
2111 let thread_id = self.id().clone();
2112 let github_login = self
2113 .project
2114 .read(cx)
2115 .user_store()
2116 .read(cx)
2117 .current_user()
2118 .map(|user| user.github_login.clone());
2119 let client = self.project.read(cx).client().clone();
2120 let serialize_task = self.serialize(cx);
2121
2122 cx.background_executor()
2123 .spawn(async move {
2124 if let Ok(serialized_thread) = serialize_task.await {
2125 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2126 telemetry::event!(
2127 "Agent Thread Auto-Captured",
2128 thread_id = thread_id.to_string(),
2129 thread_data = thread_data,
2130 auto_capture_reason = "tracked_user",
2131 github_login = github_login
2132 );
2133
2134 client.telemetry().flush_events().await;
2135 }
2136 }
2137 })
2138 .detach();
2139 }
2140
    /// Total token usage accumulated across every completion in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2144
2145 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2146 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2147 return TotalTokenUsage::default();
2148 };
2149
2150 let max = model.model.max_token_count();
2151
2152 let index = self
2153 .messages
2154 .iter()
2155 .position(|msg| msg.id == message_id)
2156 .unwrap_or(0);
2157
2158 if index == 0 {
2159 return TotalTokenUsage { total: 0, max };
2160 }
2161
2162 let token_usage = &self
2163 .request_token_usage
2164 .get(index - 1)
2165 .cloned()
2166 .unwrap_or_default();
2167
2168 TotalTokenUsage {
2169 total: token_usage.total_tokens() as usize,
2170 max,
2171 }
2172 }
2173
2174 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2175 let model_registry = LanguageModelRegistry::read_global(cx);
2176 let Some(model) = model_registry.default_model() else {
2177 return TotalTokenUsage::default();
2178 };
2179
2180 let max = model.model.max_token_count();
2181
2182 if let Some(exceeded_error) = &self.exceeded_window_error {
2183 if model.model.id() == exceeded_error.model_id {
2184 return TotalTokenUsage {
2185 total: exceeded_error.token_count,
2186 max,
2187 };
2188 }
2189 }
2190
2191 let total = self
2192 .token_usage_at_last_message()
2193 .unwrap_or_default()
2194 .total_tokens() as usize;
2195
2196 TotalTokenUsage { total, max }
2197 }
2198
    /// Token usage recorded for the most recent message, if any.
    ///
    /// Falls back to the last recorded entry when `request_token_usage` is
    /// shorter than `messages` (usage can lag behind streamed messages).
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2205
    /// Records `token_usage` for the most recent message.
    ///
    /// Grows `request_token_usage` to match the message count, backfilling any
    /// gap with the previously known usage, then overwrites the final entry.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2215
2216 pub fn deny_tool_use(
2217 &mut self,
2218 tool_use_id: LanguageModelToolUseId,
2219 tool_name: Arc<str>,
2220 cx: &mut Context<Self>,
2221 ) {
2222 let err = Err(anyhow::anyhow!(
2223 "Permission to run tool action denied by user"
2224 ));
2225
2226 self.tool_use
2227 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2228 self.tool_finished(tool_use_id.clone(), None, true, cx);
2229 }
2230}
2231
/// User-facing errors surfaced while interacting with the language model,
/// delivered to the UI via [`ThreadEvent::ShowError`].
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider rejected the request for lack of payment.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's configured monthly spend cap has been hit.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The user's plan ran out of model requests.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, presented with a header and detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2246
/// Events emitted by a `Thread` as it streams completions, runs tools, and
/// maintains its summaries.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// The provider reported updated request usage.
    UsageUpdated(RequestUsage),
    /// A streamed completion event was applied to the thread.
    StreamedCompletion,
    /// A text chunk arrived (emitted before the chunk is applied).
    ReceivedTextChunk,
    /// Assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant thinking text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// Partial tool-use input streamed in before the input was complete.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The completion finished, with its stop reason or error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A new message was inserted into the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A short thread title finished generating.
    SummaryGenerated,
    /// The thread's summary was changed.
    SummaryChanged,
    /// The model requested tool uses that should now be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool use produced output (or was canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's git checkpoint state changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
}
2278
// `Thread` entities emit `ThreadEvent`s to notify observers.
impl EventEmitter<ThreadEvent> for Thread {}
2280
/// An in-flight completion request and the task driving it.
struct PendingCompletion {
    // Identifier for this completion; presumably used to tell whether a
    // finished task is still the latest one — confirm against the code that
    // creates `PendingCompletion`s (outside this chunk).
    id: usize,
    // Held only so the streaming task stays alive for the lifetime of this
    // struct; never read.
    _task: Task<()>,
}
2285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // Verifies that an attached file is formatted into the message's context
    // and included in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // messages[0] is presumably a system message — confirm against
        // `to_completion_request` (outside this chunk).
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    // Verifies that a context item already attached to an earlier message is
    // not repeated in later messages.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    // Verifies that messages without attached context have an empty context
    // string and pass through to the request unchanged.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    // Verifies that editing a buffer after it was attached causes a
    // stale-buffer notification to be appended to the next request.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    // Registers every settings global the thread machinery touches; must be
    // called before any other setup in each test.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    // Builds the full test fixture: a workspace over `project`, a loaded
    // thread store, a fresh thread, and an empty context store.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    // Opens `path` as a buffer in `project` and attaches that buffer to
    // `context_store`; returns the buffer for further edits.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}