1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
42
/// Globally-unique identifier for a [`Thread`], stored as a shared string
/// (a v4 UUID when created via [`ThreadId::new`]).
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
47
48impl ThreadId {
49 pub fn new() -> Self {
50 Self(Uuid::new_v4().to_string().into())
51 }
52}
53
54impl std::fmt::Display for ThreadId {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
60impl From<&str> for ThreadId {
61 fn from(value: &str) -> Self {
62 Self(value.into())
63 }
64}
65
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// Stored as a shared string (a v4 UUID when created via [`PromptId::new`]).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
71
72impl PromptId {
73 pub fn new() -> Self {
74 Self(Uuid::new_v4().to_string().into())
75 }
76}
77
78impl std::fmt::Display for PromptId {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "{}", self.0)
81 }
82}
83
/// Monotonically-increasing identifier of a [`Message`] within a single thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
86
87impl MessageId {
88 fn post_inc(&mut self) -> Self {
89 Self(post_inc(&mut self.0))
90 }
91}
92
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content chunks: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Attached context (files, symbols, etc.) already formatted as text.
    pub context: String,
    /// Images attached alongside this message.
    pub images: Vec<LanguageModelImage>,
}
102
103impl Message {
104 /// Returns whether the message contains any meaningful text that should be displayed
105 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
106 pub fn should_display_content(&self) -> bool {
107 self.segments.iter().all(|segment| segment.should_display())
108 }
109
110 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
111 if let Some(MessageSegment::Thinking {
112 text: segment,
113 signature: current_signature,
114 }) = self.segments.last_mut()
115 {
116 if let Some(signature) = signature {
117 *current_signature = Some(signature);
118 }
119 segment.push_str(text);
120 } else {
121 self.segments.push(MessageSegment::Thinking {
122 text: text.to_string(),
123 signature,
124 });
125 }
126 }
127
128 pub fn push_text(&mut self, text: &str) {
129 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
130 segment.push_str(text);
131 } else {
132 self.segments.push(MessageSegment::Text(text.to_string()));
133 }
134 }
135
136 pub fn to_string(&self) -> String {
137 let mut result = String::new();
138
139 if !self.context.is_empty() {
140 result.push_str(&self.context);
141 }
142
143 for segment in &self.segments {
144 match segment {
145 MessageSegment::Text(text) => result.push_str(text),
146 MessageSegment::Thinking { text, .. } => {
147 result.push_str("<think>\n");
148 result.push_str(text);
149 result.push_str("\n</think>");
150 }
151 MessageSegment::RedactedThinking(_) => {}
152 }
153 }
154
155 result
156 }
157}
158
/// One typed chunk of a message's content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary visible text.
    Text(String),
    /// Model "thinking" output, with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted thinking, kept only as opaque bytes.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has meaningful content to display.
    ///
    /// Text and thinking segments are displayable only when non-empty
    /// (the original check was inverted, reporting empty segments as
    /// displayable); redacted thinking is never displayed.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
178
/// Snapshot of the project's state at a point in time, attached to
/// serialized threads.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per worktree in the project.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was captured.
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree's path and (optionally) its git state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}

/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// Remote URL, if one was resolved.
    pub remote_url: Option<String>,
    /// SHA of HEAD, if resolved.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch, if resolved.
    pub current_branch: Option<String>,
    /// Uncommitted diff text, if captured.
    pub diff: Option<String>,
}
199
/// Associates a git checkpoint with the message it was taken at, so the
/// project can be rolled back to that message's state.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Positive or negative user feedback on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

/// Outcome of the most recent checkpoint restore attempt.
pub enum LastRestoreCheckpoint {
    /// Restore is still in flight.
    Pending {
        message_id: MessageId,
    },
    /// Restore failed with the given error message.
    Error {
        message_id: MessageId,
        error: String,
    },
}
221
222impl LastRestoreCheckpoint {
223 pub fn message_id(&self) -> MessageId {
224 match self {
225 LastRestoreCheckpoint::Pending { message_id } => *message_id,
226 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
227 }
228 }
229}
230
/// Progress of the thread's detailed summary generation.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// Generation is in flight.
    // NOTE(review): `message_id` appears to mark the message the summary
    // covers up to — confirm against the generation call site.
    Generating {
        message_id: MessageId,
    },
    /// Generation finished, producing `text`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
243
/// Running token count for a thread, relative to the model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens the context window allows.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or over it.
    ///
    /// In debug builds the warning threshold can be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable; a missing or
    /// malformed value falls back to the default (previously a malformed
    /// value panicked via `unwrap`).
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // NOTE(review): when `max` is 0 (e.g. `TotalTokenUsage::default()`),
        // `total >= max` holds and this reports `Exceeded` — confirm callers
        // never build a usage with an unknown window size.
        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more consumed.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Coarse classification produced by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
284
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Short title; `None` until a summary has been generated.
    summary: Option<SharedString>,
    /// In-flight summary-generation task, if any.
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    /// ID that will be assigned to the next inserted message.
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    /// All context attached to this thread, keyed by ID.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Which context IDs were attached to which message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    /// Git checkpoints recorded per message, for rollback.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Status of the most recent checkpoint restore, if one happened.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken at the latest user message; finalized or discarded
    /// once generation completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Snapshot of the project captured when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    /// Set when the last message exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Optional hook invoked with each request and its completion events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Remaining turns before the thread stops sending to the model.
    remaining_turns: u32,
}
320
/// Recorded when a completion request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
328
329impl Thread {
    /// Creates a new, empty thread with a freshly-generated ID.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project's state in the background so it can be
                // included when the thread is later serialized.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            // Effectively unlimited until a caller sets a budget.
            remaining_turns: u32::MAX,
        }
    }
375
    /// Reconstructs a thread from its serialized representation.
    ///
    /// Message IDs are preserved and `next_message_id` continues after the
    /// highest serialized ID. A fresh prompt ID is generated, since prompt
    /// IDs are not persisted.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        // Rebuild tool-use bookkeeping from the serialized messages.
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    context: message.context,
                    // Images are not part of the serialized format.
                    images: Vec::new(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
450
    /// Installs a hook that observes every completion request along with the
    /// completion events (or error strings) it produced.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
458
    /// The thread's unique ID.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread contains no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt ID; called when the user submits a new prompt.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if one hasn't been generated yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Placeholder title used until a summary is generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The summary, falling back to [`Self::DEFAULT_SUMMARY`].
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
492
493 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
494 let Some(current_summary) = &self.summary else {
495 // Don't allow setting summary until generated
496 return;
497 };
498
499 let mut new_summary = new_summary.into();
500
501 if new_summary.is_empty() {
502 new_summary = Self::DEFAULT_SUMMARY;
503 }
504
505 if current_summary != &new_summary {
506 self.summary = Some(new_summary);
507 cx.emit(ThreadEvent::SummaryChanged);
508 }
509 }
510
511 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
512 self.latest_detailed_summary()
513 .unwrap_or_else(|| self.text().into())
514 }
515
516 fn latest_detailed_summary(&self) -> Option<SharedString> {
517 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
518 Some(text.clone())
519 } else {
520 None
521 }
522 }
523
    /// Returns the message with the given ID, if present.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    /// Iterates over all messages in insertion order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is streaming or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by its tool-use ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// Returns the checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
561
    /// Restores the project to the given git checkpoint and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Progress and failure are surfaced via `last_restore_checkpoint` and
    /// `ThreadEvent::CheckpointChanged`.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so the UI can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so it can be shown for this message.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // On success, also drop the now-rolled-back messages.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
597
    /// Finishes the checkpoint taken at the last user message once generation
    /// has completed.
    ///
    /// Compares the pending checkpoint with the repository's current state:
    /// if nothing changed, the checkpoint is deleted; otherwise it is kept so
    /// the user can roll back. The freshly-taken comparison checkpoint is
    /// always deleted afterwards.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still streaming or running tools; nothing to finalize yet.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // Treat a failed comparison as "changed" so the checkpoint is kept.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    // No changes since the user message; the checkpoint is useless.
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The comparison checkpoint was only needed transiently.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            // Couldn't take a comparison checkpoint: keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
648
649 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
650 self.checkpoints_by_message
651 .insert(checkpoint.message_id, checkpoint);
652 cx.emit(ThreadEvent::CheckpointChanged);
653 cx.notify();
654 }
655
    /// Status of the most recent checkpoint restore, if one happened.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
659
660 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
661 let Some(message_ix) = self
662 .messages
663 .iter()
664 .rposition(|message| message.id == message_id)
665 else {
666 return;
667 };
668 for deleted_message in self.messages.drain(message_ix..) {
669 self.context_by_message.remove(&deleted_message.id);
670 self.checkpoints_by_message.remove(&deleted_message.id);
671 }
672 cx.notify();
673 }
674
675 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
676 self.context_by_message
677 .get(&id)
678 .into_iter()
679 .flat_map(|context| {
680 context
681 .iter()
682 .filter_map(|context_id| self.context.get(&context_id))
683 })
684 }
685
    /// Returns whether all of the tool uses have finished running.
    /// (Errored tool uses count as finished.)
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
695
    /// All tool uses recorded for the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// All tool results recorded for the given message.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    /// The result of a tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The textual output of a completed tool use.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool result, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// Whether the given message has any tool results attached.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }

    /// Filter out contexts that have already been included in previous messages
    pub fn filter_new_context<'a>(
        &self,
        context: impl Iterator<Item = &'a AssistantContext>,
    ) -> impl Iterator<Item = &'a AssistantContext> {
        context.filter(|ctx| self.is_context_new(ctx))
    }

    /// Whether this context has not yet been attached to the thread.
    fn is_context_new(&self, context: &AssistantContext) -> bool {
        !self.context.contains_key(&context.id())
    }
731
732 pub fn insert_user_message(
733 &mut self,
734 text: impl Into<String>,
735 context: Vec<AssistantContext>,
736 git_checkpoint: Option<GitStoreCheckpoint>,
737 cx: &mut Context<Self>,
738 ) -> MessageId {
739 let text = text.into();
740
741 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
742
743 let new_context: Vec<_> = context
744 .into_iter()
745 .filter(|ctx| self.is_context_new(ctx))
746 .collect();
747
748 if !new_context.is_empty() {
749 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
750 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
751 message.context = context_string;
752 }
753 }
754
755 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
756 message.images = new_context
757 .iter()
758 .filter_map(|context| {
759 if let AssistantContext::Image(image_context) = context {
760 image_context.image_task.clone().now_or_never().flatten()
761 } else {
762 None
763 }
764 })
765 .collect::<Vec<_>>();
766 }
767
768 self.action_log.update(cx, |log, cx| {
769 // Track all buffers added as context
770 for ctx in &new_context {
771 match ctx {
772 AssistantContext::File(file_ctx) => {
773 log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
774 }
775 AssistantContext::Directory(dir_ctx) => {
776 for context_buffer in &dir_ctx.context_buffers {
777 log.track_buffer(context_buffer.buffer.clone(), cx);
778 }
779 }
780 AssistantContext::Symbol(symbol_ctx) => {
781 log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
782 }
783 AssistantContext::Selection(selection_context) => {
784 log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
785 }
786 AssistantContext::FetchedUrl(_)
787 | AssistantContext::Thread(_)
788 | AssistantContext::Rules(_)
789 | AssistantContext::Image(_) => {}
790 }
791 }
792 });
793 }
794
795 let context_ids = new_context
796 .iter()
797 .map(|context| context.id())
798 .collect::<Vec<_>>();
799 self.context.extend(
800 new_context
801 .into_iter()
802 .map(|context| (context.id(), context)),
803 );
804 self.context_by_message.insert(message_id, context_ids);
805
806 if let Some(git_checkpoint) = git_checkpoint {
807 self.pending_checkpoint = Some(ThreadCheckpoint {
808 message_id,
809 git_checkpoint,
810 });
811 }
812
813 self.auto_capture_telemetry(cx);
814
815 message_id
816 }
817
818 pub fn insert_message(
819 &mut self,
820 role: Role,
821 segments: Vec<MessageSegment>,
822 cx: &mut Context<Self>,
823 ) -> MessageId {
824 let id = self.next_message_id.post_inc();
825 self.messages.push(Message {
826 id,
827 role,
828 segments,
829 context: String::new(),
830 images: Vec::new(),
831 });
832 self.touch_updated_at();
833 cx.emit(ThreadEvent::MessageAdded(id));
834 id
835 }
836
837 pub fn edit_message(
838 &mut self,
839 id: MessageId,
840 new_role: Role,
841 new_segments: Vec<MessageSegment>,
842 cx: &mut Context<Self>,
843 ) -> bool {
844 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
845 return false;
846 };
847 message.role = new_role;
848 message.segments = new_segments;
849 self.touch_updated_at();
850 cx.emit(ThreadEvent::MessageEdited(id));
851 true
852 }
853
854 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
855 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
856 return false;
857 };
858 self.messages.remove(index);
859 self.context_by_message.remove(&id);
860 self.touch_updated_at();
861 cx.emit(ThreadEvent::MessageDeleted(id));
862 true
863 }
864
865 /// Returns the representation of this [`Thread`] in a textual form.
866 ///
867 /// This is the representation we use when attaching a thread as context to another thread.
868 pub fn text(&self) -> String {
869 let mut text = String::new();
870
871 for message in &self.messages {
872 text.push_str(match message.role {
873 language_model::Role::User => "User:",
874 language_model::Role::Assistant => "Assistant:",
875 language_model::Role::System => "System:",
876 });
877 text.push('\n');
878
879 for segment in &message.segments {
880 match segment {
881 MessageSegment::Text(content) => text.push_str(content),
882 MessageSegment::Thinking { text: content, .. } => {
883 text.push_str(&format!("<think>{}</think>", content))
884 }
885 MessageSegment::RedactedThinking(_) => {}
886 }
887 }
888 text.push('\n');
889 }
890
891 text
892 }
893
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot to resolve before reading the
    /// thread state, so the serialized form always includes it.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results are stored per message so they can
                        // be reattached by `ToolUseState` on deserialization.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
957
    /// The number of turns the thread may still take before it stops
    /// sending requests to the model.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    /// Sets the turn budget; `u32::MAX` (the initial value) is effectively unlimited.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
965
966 pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
967 if self.remaining_turns == 0 {
968 return;
969 }
970
971 self.remaining_turns -= 1;
972
973 let mut request = self.to_completion_request(cx);
974 if model.supports_tools() {
975 request.tools = {
976 let mut tools = Vec::new();
977 tools.extend(
978 self.tools()
979 .read(cx)
980 .enabled_tools(cx)
981 .into_iter()
982 .filter_map(|tool| {
983 // Skip tools that cannot be supported
984 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
985 Some(LanguageModelRequestTool {
986 name: tool.name(),
987 description: tool.description(),
988 input_schema,
989 })
990 }),
991 );
992
993 tools
994 };
995 }
996
997 self.stream_completion(request, model, cx);
998 }
999
1000 pub fn used_tools_since_last_user_message(&self) -> bool {
1001 for message in self.messages.iter().rev() {
1002 if self.tool_use.message_has_tool_results(message.id) {
1003 return true;
1004 } else if message.role == Role::User {
1005 return false;
1006 }
1007 }
1008
1009 false
1010 }
1011
    /// Assembles the full completion request for the current thread state:
    /// the system prompt (from project context), then every message with its
    /// attached context, images, segments, and tool uses/results. The last
    /// message is marked cacheable.
    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // Generate the system prompt from project context; surface failures
        // to the user rather than silently sending without one.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Tool results must precede the message's own content.
            self.tool_use
                .attach_tool_results(message.id, &mut request_message);

            if !message.context.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.context.to_string()));
            }

            if !message.images.is_empty() {
                // Some providers only support image parts after an initial text part
                if request_message.content.is_empty() {
                    request_message
                        .content
                        .push(MessageContent::Text("Images attached by user:".to_string()));
                }

                for image in &message.images {
                    request_message
                        .content
                        .push(MessageContent::Image(image.clone()))
                }
            }

            // Empty segments are dropped; redacted thinking is forwarded verbatim.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1123
1124 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1125 let mut request = LanguageModelRequest {
1126 thread_id: None,
1127 prompt_id: None,
1128 messages: vec![],
1129 tools: Vec::new(),
1130 stop: Vec::new(),
1131 temperature: None,
1132 };
1133
1134 for message in &self.messages {
1135 let mut request_message = LanguageModelRequestMessage {
1136 role: message.role,
1137 content: Vec::new(),
1138 cache: false,
1139 };
1140
1141 // Skip tool results during summarization.
1142 if self.tool_use.message_has_tool_results(message.id) {
1143 continue;
1144 }
1145
1146 for segment in &message.segments {
1147 match segment {
1148 MessageSegment::Text(text) => request_message
1149 .content
1150 .push(MessageContent::Text(text.clone())),
1151 MessageSegment::Thinking { .. } => {}
1152 MessageSegment::RedactedThinking(_) => {}
1153 }
1154 }
1155
1156 if request_message.content.is_empty() {
1157 continue;
1158 }
1159
1160 request.messages.push(request_message);
1161 }
1162
1163 request.messages.push(LanguageModelRequestMessage {
1164 role: Role::User,
1165 content: vec![MessageContent::Text(added_user_message)],
1166 cache: false,
1167 });
1168
1169 request
1170 }
1171
1172 fn attached_tracked_files_state(
1173 &self,
1174 messages: &mut Vec<LanguageModelRequestMessage>,
1175 cx: &App,
1176 ) {
1177 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1178
1179 let mut stale_message = String::new();
1180
1181 let action_log = self.action_log.read(cx);
1182
1183 for stale_file in action_log.stale_buffers(cx) {
1184 let Some(file) = stale_file.read(cx).file() else {
1185 continue;
1186 };
1187
1188 if stale_message.is_empty() {
1189 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1190 }
1191
1192 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1193 }
1194
1195 let mut content = Vec::with_capacity(2);
1196
1197 if !stale_message.is_empty() {
1198 content.push(stale_message.into());
1199 }
1200
1201 if !content.is_empty() {
1202 let context_message = LanguageModelRequestMessage {
1203 role: Role::User,
1204 content,
1205 cache: false,
1206 };
1207
1208 messages.push(context_message);
1209 }
1210 }
1211
    /// Streams a completion for `request` from `model`, applying events to
    /// this thread as they arrive.
    ///
    /// The spawned task appends streamed text/thinking chunks to the last
    /// assistant message (inserting one when needed), registers tool uses,
    /// tracks token usage, and emits the corresponding [`ThreadEvent`]s. When
    /// the stream stops with `StopReason::ToolUse`, pending tools are started;
    /// errors are mapped to `ThreadEvent::ShowError` and cancel the
    /// completion. The task is kept alive via `self.pending_completions`.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Capture the request (and, below, each response event) only when a
        // request callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot usage before the request so we can report this
            // completion's delta in telemetry at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                // Apply each streamed event to the thread as it arrives.
                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Usage updates are running totals for this
                                // request, so only add the delta since the
                                // previous update to the thread-wide total.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking {
                                                text: chunk.to_string(),
                                                signature,
                                            }],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the most recent
                                // assistant message, creating one if needed.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only partially-streamed inputs are surfaced
                                // via `StreamedToolUse` below.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    // Yield between events so other tasks get a chance to run.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known error types to specific user-facing
                            // events; anything else becomes a generic message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        // Remember the overflow so later
                                        // token-usage queries for this model
                                        // can report it.
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        // Report only the tokens consumed by this completion.
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1481
    /// Generates a short title for the thread using the configured summary
    /// model and stores it in `self.summary`.
    ///
    /// No-op when no summary model is configured or its provider isn't
    /// authenticated. Emits `ThreadEvent::SummaryGenerated` when the spawned
    /// task finishes.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                // Only keep the first line of output as the title.
                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // Keep the existing summary if the model returned nothing.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1536
    /// Starts generating a detailed Markdown summary of the conversation.
    ///
    /// Returns `None` when a summary for the last message is already generated
    /// or in flight, when the thread has no messages, when no summary model is
    /// configured, or when its provider isn't authenticated.
    ///
    /// `detailed_summary_state` is set to `Generating` immediately; the
    /// returned task resolves it to `Generated` (or back to `NotGenerated` if
    /// the request fails).
    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // On failure to start streaming, reset state so a retry is possible.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            // Accumulate the full streamed response; chunks that fail to
            // arrive are logged and skipped.
            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }
1603
1604 pub fn is_generating_detailed_summary(&self) -> bool {
1605 matches!(
1606 self.detailed_summary_state,
1607 DetailedSummaryState::Generating { .. }
1608 )
1609 }
1610
    /// Dispatches every idle pending tool use.
    ///
    /// Tools whose input requires confirmation (unless the user enabled
    /// "always allow tool actions") are parked and a `ToolConfirmationNeeded`
    /// event is emitted; all others are run immediately via [`Self::run_tool`].
    /// Returns the pending tool uses that were considered.
    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Tools receive the full request message history as context.
        let request = self.to_completion_request(cx);
        let messages = Arc::new(request.messages);
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            // Tool uses whose named tool isn't registered are silently skipped.
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        cx,
                    );
                }
            }
        }

        pending_tool_uses
    }
1651
1652 pub fn run_tool(
1653 &mut self,
1654 tool_use_id: LanguageModelToolUseId,
1655 ui_text: impl Into<SharedString>,
1656 input: serde_json::Value,
1657 messages: &[LanguageModelRequestMessage],
1658 tool: Arc<dyn Tool>,
1659 cx: &mut Context<Thread>,
1660 ) {
1661 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1662 self.tool_use
1663 .run_pending_tool(tool_use_id, ui_text.into(), task);
1664 }
1665
    /// Spawns the actual execution of a tool call.
    ///
    /// Disabled tools resolve to an error result instead of running. The
    /// tool's card (if any) is stored immediately; the returned task records
    /// the tool's output on completion and drives [`Self::tool_finished`].
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            // The error becomes the tool's result rather than being raised.
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            cx,
                        );
                        // `canceled = false`: this is a normal completion.
                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
                    })
                    .ok();
            }
        })
    }
1712
1713 fn tool_finished(
1714 &mut self,
1715 tool_use_id: LanguageModelToolUseId,
1716 pending_tool_use: Option<PendingToolUse>,
1717 canceled: bool,
1718 cx: &mut Context<Self>,
1719 ) {
1720 if self.all_tools_finished() {
1721 let model_registry = LanguageModelRegistry::read_global(cx);
1722 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1723 self.attach_tool_results(cx);
1724 if !canceled {
1725 self.send_to_model(model, cx);
1726 }
1727 }
1728 }
1729
1730 cx.emit(ThreadEvent::ToolFinished {
1731 tool_use_id,
1732 pending_tool_use,
1733 });
1734 }
1735
    /// Insert an empty user message to be populated with tool results upon
    /// send.
    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Tool results are assumed to be waiting on the next message id, so they will populate
        // this empty message before sending to model. Would prefer this to be more straightforward.
        self.insert_message(Role::User, vec![], cx);
        self.auto_capture_telemetry(cx);
    }
1743
1744 /// Cancels the last pending completion, if there are any pending.
1745 ///
1746 /// Returns whether a completion was canceled.
1747 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1748 let canceled = if self.pending_completions.pop().is_some() {
1749 true
1750 } else {
1751 let mut canceled = false;
1752 for pending_tool_use in self.tool_use.cancel_pending() {
1753 canceled = true;
1754 self.tool_finished(
1755 pending_tool_use.id.clone(),
1756 Some(pending_tool_use),
1757 true,
1758 cx,
1759 );
1760 }
1761 canceled
1762 };
1763 self.finalize_pending_checkpoint(cx);
1764 canceled
1765 }
1766
    /// The thread-level feedback rating, if the user submitted one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1770
    /// The feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1774
    /// Records the user's rating for `message_id` and reports it — together
    /// with a serialized copy of the thread and a project snapshot — via
    /// telemetry.
    ///
    /// Submitting the same rating for the same message again is a no-op.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure degrades to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
1832
    /// Records thread-level feedback.
    ///
    /// When an assistant message exists, the rating is attached to the most
    /// recent one via [`Self::report_message_feedback`]; otherwise the rating
    /// is stored on the thread itself and reported via telemetry directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // A serialization failure degrades to a null payload rather
                // than dropping the event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
1878
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Captures one [`WorktreeSnapshot`] per visible worktree plus the paths
    /// of all dirty buffers, stamped with the current UTC time.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect paths of dirty (unsaved) file-backed buffers.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
1916
1917 fn worktree_snapshot(
1918 worktree: Entity<project::Worktree>,
1919 git_store: Entity<GitStore>,
1920 cx: &App,
1921 ) -> Task<WorktreeSnapshot> {
1922 cx.spawn(async move |cx| {
1923 // Get worktree path and snapshot
1924 let worktree_info = cx.update(|app_cx| {
1925 let worktree = worktree.read(app_cx);
1926 let path = worktree.abs_path().to_string_lossy().to_string();
1927 let snapshot = worktree.snapshot();
1928 (path, snapshot)
1929 });
1930
1931 let Ok((worktree_path, _snapshot)) = worktree_info else {
1932 return WorktreeSnapshot {
1933 worktree_path: String::new(),
1934 git_state: None,
1935 };
1936 };
1937
1938 let git_state = git_store
1939 .update(cx, |git_store, cx| {
1940 git_store
1941 .repositories()
1942 .values()
1943 .find(|repo| {
1944 repo.read(cx)
1945 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1946 .is_some()
1947 })
1948 .cloned()
1949 })
1950 .ok()
1951 .flatten()
1952 .map(|repo| {
1953 repo.update(cx, |repo, _| {
1954 let current_branch =
1955 repo.branch.as_ref().map(|branch| branch.name.to_string());
1956 repo.send_job(None, |state, _| async move {
1957 let RepositoryState::Local { backend, .. } = state else {
1958 return GitState {
1959 remote_url: None,
1960 head_sha: None,
1961 current_branch,
1962 diff: None,
1963 };
1964 };
1965
1966 let remote_url = backend.remote_url("origin");
1967 let head_sha = backend.head_sha();
1968 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1969
1970 GitState {
1971 remote_url,
1972 head_sha,
1973 current_branch,
1974 diff,
1975 }
1976 })
1977 })
1978 });
1979
1980 let git_state = match git_state {
1981 Some(git_state) => match git_state.ok() {
1982 Some(git_state) => git_state.await.ok(),
1983 None => None,
1984 },
1985 None => None,
1986 };
1987
1988 WorktreeSnapshot {
1989 worktree_path,
1990 git_state,
1991 }
1992 })
1993 }
1994
1995 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1996 let mut markdown = Vec::new();
1997
1998 if let Some(summary) = self.summary() {
1999 writeln!(markdown, "# {summary}\n")?;
2000 };
2001
2002 for message in self.messages() {
2003 writeln!(
2004 markdown,
2005 "## {role}\n",
2006 role = match message.role {
2007 Role::User => "User",
2008 Role::Assistant => "Assistant",
2009 Role::System => "System",
2010 }
2011 )?;
2012
2013 if !message.context.is_empty() {
2014 writeln!(markdown, "{}", message.context)?;
2015 }
2016
2017 for segment in &message.segments {
2018 match segment {
2019 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2020 MessageSegment::Thinking { text, .. } => {
2021 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2022 }
2023 MessageSegment::RedactedThinking(_) => {}
2024 }
2025 }
2026
2027 for tool_use in self.tool_uses_for_message(message.id, cx) {
2028 writeln!(
2029 markdown,
2030 "**Use Tool: {} ({})**",
2031 tool_use.name, tool_use.id
2032 )?;
2033 writeln!(markdown, "```json")?;
2034 writeln!(
2035 markdown,
2036 "{}",
2037 serde_json::to_string_pretty(&tool_use.input)?
2038 )?;
2039 writeln!(markdown, "```")?;
2040 }
2041
2042 for tool_result in self.tool_results_for_message(message.id) {
2043 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
2044 if tool_result.is_error {
2045 write!(markdown, " (Error)")?;
2046 }
2047
2048 writeln!(markdown, "**\n")?;
2049 writeln!(markdown, "{}", tool_result.content)?;
2050 }
2051 }
2052
2053 Ok(String::from_utf8_lossy(&markdown).to_string())
2054 }
2055
    /// Keeps the edits within `buffer_range` of `buffer`, delegating to
    /// [`ActionLog::keep_edits_in_range`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2066
    /// Keeps all tracked edits, delegating to [`ActionLog::keep_all_edits`].
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2071
    /// Rejects the edits within the given ranges of `buffer`, delegating to
    /// [`ActionLog::reject_edits_in_ranges`].
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2082
    /// The action log associated with this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2086
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2090
2091 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2092 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2093 return;
2094 }
2095
2096 let now = Instant::now();
2097 if let Some(last) = self.last_auto_capture_at {
2098 if now.duration_since(last).as_secs() < 10 {
2099 return;
2100 }
2101 }
2102
2103 self.last_auto_capture_at = Some(now);
2104
2105 let thread_id = self.id().clone();
2106 let github_login = self
2107 .project
2108 .read(cx)
2109 .user_store()
2110 .read(cx)
2111 .current_user()
2112 .map(|user| user.github_login.clone());
2113 let client = self.project.read(cx).client().clone();
2114 let serialize_task = self.serialize(cx);
2115
2116 cx.background_executor()
2117 .spawn(async move {
2118 if let Ok(serialized_thread) = serialize_task.await {
2119 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2120 telemetry::event!(
2121 "Agent Thread Auto-Captured",
2122 thread_id = thread_id.to_string(),
2123 thread_data = thread_data,
2124 auto_capture_reason = "tracked_user",
2125 github_login = github_login
2126 );
2127
2128 client.telemetry().flush_events().await;
2129 }
2130 }
2131 })
2132 .detach();
2133 }
2134
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2138
2139 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2140 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2141 return TotalTokenUsage::default();
2142 };
2143
2144 let max = model.model.max_token_count();
2145
2146 let index = self
2147 .messages
2148 .iter()
2149 .position(|msg| msg.id == message_id)
2150 .unwrap_or(0);
2151
2152 if index == 0 {
2153 return TotalTokenUsage { total: 0, max };
2154 }
2155
2156 let token_usage = &self
2157 .request_token_usage
2158 .get(index - 1)
2159 .cloned()
2160 .unwrap_or_default();
2161
2162 TotalTokenUsage {
2163 total: token_usage.total_tokens() as usize,
2164 max,
2165 }
2166 }
2167
2168 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2169 let model_registry = LanguageModelRegistry::read_global(cx);
2170 let Some(model) = model_registry.default_model() else {
2171 return TotalTokenUsage::default();
2172 };
2173
2174 let max = model.model.max_token_count();
2175
2176 if let Some(exceeded_error) = &self.exceeded_window_error {
2177 if model.model.id() == exceeded_error.model_id {
2178 return TotalTokenUsage {
2179 total: exceeded_error.token_count,
2180 max,
2181 };
2182 }
2183 }
2184
2185 let total = self
2186 .token_usage_at_last_message()
2187 .unwrap_or_default()
2188 .total_tokens() as usize;
2189
2190 TotalTokenUsage { total, max }
2191 }
2192
    /// The token usage recorded for the most recent message, falling back to
    /// the latest recorded entry when the usage vec is shorter than the
    /// message list. Returns `None` when no usage has been recorded.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2199
    /// Records `token_usage` for the most recent message.
    ///
    /// The per-message usage vec is first resized to match the message count,
    /// padding any missing entries with the last known usage so indices stay
    /// aligned with `self.messages`; then the final entry is overwritten.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2209
2210 pub fn deny_tool_use(
2211 &mut self,
2212 tool_use_id: LanguageModelToolUseId,
2213 tool_name: Arc<str>,
2214 cx: &mut Context<Self>,
2215 ) {
2216 let err = Err(anyhow::anyhow!(
2217 "Permission to run tool action denied by user"
2218 ));
2219
2220 self.tool_use
2221 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2222 self.tool_finished(tool_use_id.clone(), None, true, cx);
2223 }
2224}
2225
/// Errors surfaced to the user via [`ThreadEvent::ShowError`].
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider requires payment before serving further requests.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The model request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a header and message to display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2240
/// Events emitted by a `Thread` as completions stream in, tools run, and
/// thread metadata changes.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// The provider reported request-usage information.
    UsageUpdated(RequestUsage),
    /// A completion event was processed and applied to the thread.
    StreamedCompletion,
    /// A text chunk arrived from the model.
    ReceivedTextChunk,
    /// Streamed text was appended to the given assistant message.
    StreamedAssistantText(MessageId, String),
    /// Streamed thinking content was appended to the given assistant message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use whose input is still streaming in.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// Generation of the thread's (short) summary finished.
    SummaryGenerated,
    /// The thread's summary was changed.
    SummaryChanged,
    /// The model stopped to use tools; these are the pending tool uses.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A single tool finished running (or was canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A checkpoint associated with the thread changed.
    CheckpointChanged,
    /// At least one tool use needs user confirmation before it can run.
    ToolConfirmationNeeded,
}
2272
// Lets `Thread` broadcast `ThreadEvent`s through gpui's entity event system.
impl EventEmitter<ThreadEvent> for Thread {}
2274
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifies this completion — presumably used to check whether a
    // finished task is still the current one; confirm at call sites.
    id: usize,
    // Keeps the background streaming task alive while the completion is pending.
    _task: Task<()>,
}
2279
/// Tests for how user messages, attached file context, and stale-buffer
/// notifications are assembled into completion requests.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // An attached file should appear both in the message's `context` field
    // and, prepended to the message text, in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    // A context item already sent in an earlier message should not be
    // re-attached to later messages.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    // Messages inserted without any context items should carry an empty
    // `context` string and pass through to the request unchanged.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    // Editing a buffer after it was attached as context should append a
    // "files changed since last read" notification to the next request.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert at byte offset 1 — any edit suffices to make the
            // previously-attached snapshot stale
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    // Registers every settings global the entities under test read.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    // Builds the full fixture set: a workspace window, a thread store with a
    // fresh thread, and an empty context store, all backed by `project`.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    // Opens `path` in `project`, attaches the resulting buffer to
    // `context_store`, and returns the buffer for further edits.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}