1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
42
/// Stable unique identifier for a [`Thread`].
///
/// Typically a freshly generated v4 UUID string (see [`ThreadId::new`]), but any
/// string can be wrapped via `From<&str>` (e.g. when loading persisted threads).
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
47
48impl ThreadId {
49 pub fn new() -> Self {
50 Self(Uuid::new_v4().to_string().into())
51 }
52}
53
54impl std::fmt::Display for ThreadId {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
60impl From<&str> for ThreadId {
61 fn from(value: &str) -> Self {
62 Self(value.into())
63 }
64}
65
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
///
/// Backed by a randomly generated UUID string (see [`PromptId::new`]).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
71
72impl PromptId {
73 pub fn new() -> Self {
74 Self(Uuid::new_v4().to_string().into())
75 }
76}
77
78impl std::fmt::Display for PromptId {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "{}", self.0)
81 }
82}
83
/// Thread-local identifier for a [`Message`], assigned in increasing order
/// (see `Thread::next_message_id` and [`MessageId::post_inc`]).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
86
87impl MessageId {
88 fn post_inc(&mut self) -> Self {
89 Self(post_inc(&mut self.0))
90 }
91}
92
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces: plain text, (optionally signed) thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Pre-formatted attached-context text; prepended when rendering (`to_string`)
    /// and when building completion requests.
    pub context: String,
    /// Images attached alongside this message (not persisted across serialization).
    pub images: Vec<LanguageModelImage>,
}
102
103impl Message {
104 /// Returns whether the message contains any meaningful text that should be displayed
105 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
106 pub fn should_display_content(&self) -> bool {
107 self.segments.iter().all(|segment| segment.should_display())
108 }
109
110 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
111 if let Some(MessageSegment::Thinking {
112 text: segment,
113 signature: current_signature,
114 }) = self.segments.last_mut()
115 {
116 if let Some(signature) = signature {
117 *current_signature = Some(signature);
118 }
119 segment.push_str(text);
120 } else {
121 self.segments.push(MessageSegment::Thinking {
122 text: text.to_string(),
123 signature,
124 });
125 }
126 }
127
128 pub fn push_text(&mut self, text: &str) {
129 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
130 segment.push_str(text);
131 } else {
132 self.segments.push(MessageSegment::Text(text.to_string()));
133 }
134 }
135
136 pub fn to_string(&self) -> String {
137 let mut result = String::new();
138
139 if !self.context.is_empty() {
140 result.push_str(&self.context);
141 }
142
143 for segment in &self.segments {
144 match segment {
145 MessageSegment::Text(text) => result.push_str(text),
146 MessageSegment::Thinking { text, .. } => {
147 result.push_str("<think>\n");
148 result.push_str(text);
149 result.push_str("\n</think>");
150 }
151 MessageSegment::RedactedThinking(_) => {}
152 }
153 }
154
155 result
156 }
157}
158
/// One piece of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain assistant/user text.
    Text(String),
    /// Model "thinking" output, with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque redacted thinking payload; never displayed.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has user-visible content.
    ///
    /// A segment is displayable only when it actually contains text; redacted
    /// thinking is never shown.
    pub fn should_display(&self) -> bool {
        match self {
            // Fixed: these previously returned `text.is_empty()`, i.e. `true` for
            // EMPTY segments — the opposite of the documented intent ("meaningful
            // text that should be displayed").
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
178
/// Point-in-time capture of the project's worktrees, taken when a thread is created
/// (see `Thread::new`'s `initial_project_snapshot`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when no git repository state was captured for this worktree.
    pub git_state: Option<GitState>,
}

/// Git repository state captured for a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff text, if one was captured.
    pub diff: Option<String>,
}
199
/// Associates a git checkpoint with the user message it was captured at,
/// so the project can later be restored to that point.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Thumbs-up / thumbs-down feedback on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
211
/// State of the most recent checkpoint-restore operation
/// (see `Thread::restore_checkpoint`).
pub enum LastRestoreCheckpoint {
    /// The restore is still in flight.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed with the given error message.
    Error {
        message_id: MessageId,
        error: String,
    },
}
221
222impl LastRestoreCheckpoint {
223 pub fn message_id(&self) -> MessageId {
224 match self {
225 LastRestoreCheckpoint::Pending { message_id } => *message_id,
226 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
227 }
228 }
229}
230
/// Progress of the on-demand "detailed summary" generation for a thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// Generation is in flight, started as of `message_id`.
    Generating {
        message_id: MessageId,
    },
    /// A summary was produced, current as of `message_id`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
243
/// Running token total for a thread, measured against the model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens available (the model's context window size).
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, near the limit, or over it.
    ///
    /// In debug builds the warning threshold (default 0.8) can be overridden via the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable.
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            // Fixed: a malformed override previously panicked via `unwrap()`;
            // fall back to the release default instead.
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more consumed.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close a [`TotalTokenUsage`] is to exhausting the context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
284
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// `None` until a summary has been generated; gates `set_summary`.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    /// Next ID to hand out; advanced via `MessageId::post_inc`.
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    /// All context ever attached to this thread, keyed by ID.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Which context items were attached with each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    /// Git checkpoints recorded per user message, for restoring project state.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Outcome of the most recent checkpoint restore (pending or failed), if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken at the latest user message; resolved once generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured when the thread was created (shared, awaited lazily).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    /// Token usage accumulated across the whole thread.
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    // NOTE(review): presumably used to rate-limit `auto_capture_telemetry` —
    // its body is not visible here; confirm.
    last_auto_capture_at: Option<Instant>,
    /// Optional hook observing completion requests and their streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
}
319
/// Recorded when a completion request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
327
328impl Thread {
    /// Creates an empty thread for `project` with a freshly generated ID.
    ///
    /// Kicks off an asynchronous capture of the initial project snapshot.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project state now; consumers await the shared task later.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
373
    /// Reconstructs a thread from its serialized form.
    ///
    /// Message content, tool-use state, summaries, and token usage are restored;
    /// ephemeral state (pending completions, checkpoints, feedback, callbacks) is reset.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    context: message.context,
                    // Images are not persisted (see `serialize`); restored messages
                    // start without them.
                    images: Vec::new(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
447
    /// Stores a callback that can observe completion requests and the
    /// streamed events received for them.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// The thread's stable unique ID.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// True when the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread last changed.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified right now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if one hasn't been produced yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Placeholder title used until a real summary is generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The summary, falling back to [`Self::DEFAULT_SUMMARY`].
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
489
490 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
491 let Some(current_summary) = &self.summary else {
492 // Don't allow setting summary until generated
493 return;
494 };
495
496 let mut new_summary = new_summary.into();
497
498 if new_summary.is_empty() {
499 new_summary = Self::DEFAULT_SUMMARY;
500 }
501
502 if current_summary != &new_summary {
503 self.summary = Some(new_summary);
504 cx.emit(ThreadEvent::SummaryChanged);
505 }
506 }
507
508 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
509 self.latest_detailed_summary()
510 .unwrap_or_else(|| self.text().into())
511 }
512
513 fn latest_detailed_summary(&self) -> Option<SharedString> {
514 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
515 Some(text.clone())
516 } else {
517 None
518 }
519 }
520
    /// Looks up a message by ID.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }

    /// True while a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses waiting for user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// True when any tool use is still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for message `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
558
    /// Restores the project to `checkpoint`, truncating the thread back to
    /// that message on success.
    ///
    /// Progress and failure are surfaced through [`Self::last_restore_checkpoint`]
    /// and [`ThreadEvent::CheckpointChanged`] events.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so observers can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so it can be shown next to the message.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop messages past the restored checkpoint to match project state.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
594
    /// Resolves the checkpoint taken at the last user message once generation finishes.
    ///
    /// If the project changed during generation, the checkpoint is kept (attached to
    /// its message); if nothing changed, it is deleted. Comparison failures err on the
    /// side of keeping the checkpoint.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still streaming or running tools; finalize later.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    // No changes since the user message: the checkpoint is redundant.
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The comparison checkpoint itself is never needed afterwards.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            // Couldn't take a comparison checkpoint; keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
645
    /// Records `checkpoint` against its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
656
    /// Removes `message_id` and every later message, along with their per-message
    /// context and checkpoint bookkeeping.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.context_by_message.remove(&deleted_message.id);
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
671
672 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
673 self.context_by_message
674 .get(&id)
675 .into_iter()
676 .flat_map(|context| {
677 context
678 .iter()
679 .filter_map(|context_id| self.context.get(&context_id))
680 })
681 }
682
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses recorded for message `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results recorded for message `id`.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    /// The result of a specific tool use, if one has been recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a recorded tool result.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool result, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// True when `message_id` has any recorded tool results.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }

    /// Filter out contexts that have already been included in previous messages
    pub fn filter_new_context<'a>(
        &self,
        context: impl Iterator<Item = &'a AssistantContext>,
    ) -> impl Iterator<Item = &'a AssistantContext> {
        context.filter(|ctx| self.is_context_new(ctx))
    }

    /// True when `context` hasn't been attached to this thread before.
    fn is_context_new(&self, context: &AssistantContext) -> bool {
        !self.context.contains_key(&context.id())
    }
728
    /// Appends a user message with its attached context, returning the new message's ID.
    ///
    /// Only context not already attached to the thread is recorded: its formatted text
    /// and any already-loaded images are stored on the message, and its buffers are
    /// registered with the action log. When `git_checkpoint` is given it becomes the
    /// thread's pending checkpoint (finalized after generation).
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        // Drop context that was already attached earlier in the thread.
        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                // Only image loads that have already completed are attached
                // (`now_or_never` does not wait for in-flight tasks).
                message.images = new_context
                    .iter()
                    .filter_map(|context| {
                        if let AssistantContext::Image(image_context) = context {
                            image_context.image_task.clone().now_or_never().flatten()
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>();
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.buffer_added_as_context(
                                symbol_ctx.context_symbol.buffer.clone(),
                                cx,
                            );
                        }
                        AssistantContext::Selection(selection_context) => {
                            log.buffer_added_as_context(
                                selection_context.context_buffer.buffer.clone(),
                                cx,
                            );
                        }
                        // These context kinds carry no buffers to track.
                        AssistantContext::FetchedUrl(_)
                        | AssistantContext::Thread(_)
                        | AssistantContext::Rules(_)
                        | AssistantContext::Image(_) => {}
                    }
                }
            });
        }

        // Remember the new context and its association with this message.
        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
820
821 pub fn insert_message(
822 &mut self,
823 role: Role,
824 segments: Vec<MessageSegment>,
825 cx: &mut Context<Self>,
826 ) -> MessageId {
827 let id = self.next_message_id.post_inc();
828 self.messages.push(Message {
829 id,
830 role,
831 segments,
832 context: String::new(),
833 images: Vec::new(),
834 });
835 self.touch_updated_at();
836 cx.emit(ThreadEvent::MessageAdded(id));
837 id
838 }
839
840 pub fn edit_message(
841 &mut self,
842 id: MessageId,
843 new_role: Role,
844 new_segments: Vec<MessageSegment>,
845 cx: &mut Context<Self>,
846 ) -> bool {
847 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
848 return false;
849 };
850 message.role = new_role;
851 message.segments = new_segments;
852 self.touch_updated_at();
853 cx.emit(ThreadEvent::MessageEdited(id));
854 true
855 }
856
857 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
858 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
859 return false;
860 };
861 self.messages.remove(index);
862 self.context_by_message.remove(&id);
863 self.touch_updated_at();
864 cx.emit(ThreadEvent::MessageDeleted(id));
865 true
866 }
867
868 /// Returns the representation of this [`Thread`] in a textual form.
869 ///
870 /// This is the representation we use when attaching a thread as context to another thread.
871 pub fn text(&self) -> String {
872 let mut text = String::new();
873
874 for message in &self.messages {
875 text.push_str(match message.role {
876 language_model::Role::User => "User:",
877 language_model::Role::Assistant => "Assistant:",
878 language_model::Role::System => "System:",
879 });
880 text.push('\n');
881
882 for segment in &message.segments {
883 match segment {
884 MessageSegment::Text(content) => text.push_str(content),
885 MessageSegment::Thinking { text: content, .. } => {
886 text.push_str(&format!("<think>{}</think>", content))
887 }
888 MessageSegment::RedactedThinking(_) => {}
889 }
890 }
891 text.push('\n');
892 }
893
894 text
895 }
896
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the initial project snapshot before reading the thread state.
    /// Note that message images are not serialized.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses and results are persisted per message so
                        // `deserialize` can rebuild `ToolUseState`.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
960
961 pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
962 let mut request = self.to_completion_request(cx);
963 if model.supports_tools() {
964 request.tools = {
965 let mut tools = Vec::new();
966 tools.extend(
967 self.tools()
968 .read(cx)
969 .enabled_tools(cx)
970 .into_iter()
971 .filter_map(|tool| {
972 // Skip tools that cannot be supported
973 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
974 Some(LanguageModelRequestTool {
975 name: tool.name(),
976 description: tool.description(),
977 input_schema,
978 })
979 }),
980 );
981
982 tools
983 };
984 }
985
986 self.stream_completion(request, model, cx);
987 }
988
989 pub fn used_tools_since_last_user_message(&self) -> bool {
990 for message in self.messages.iter().rev() {
991 if self.tool_use.message_has_tool_results(message.id) {
992 return true;
993 } else if message.role == Role::User {
994 return false;
995 }
996 }
997
998 false
999 }
1000
    /// Assembles the full completion request for this thread: the system prompt,
    /// every message (tool results, attached context, images, segments, tool uses),
    /// and a trailing stale-file notice, with provider prompt caching enabled on
    /// the system prompt and the last message.
    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // The system prompt comes first; failures to build it are surfaced to the
        // user rather than silently dropped.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Tool results are attached before any other content for this message.
            self.tool_use
                .attach_tool_results(message.id, &mut request_message);

            if !message.context.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.context.to_string()));
            }

            if !message.images.is_empty() {
                // Some providers only support image parts after an initial text part
                if request_message.content.is_empty() {
                    request_message
                        .content
                        .push(MessageContent::Text("Images attached by user:".to_string()));
                }

                for image in &message.images {
                    request_message
                        .content
                        .push(MessageContent::Image(image.clone()))
                }
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Empty segments are skipped entirely.
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1112
    /// Builds a request for summarizing the thread.
    ///
    /// Thinking/redacted segments and messages carrying tool results are omitted,
    /// and `added_user_message` (the summarization instruction) is appended last.
    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Skip tool results during summarization.
            if self.tool_use.message_has_tool_results(message.id) {
                continue;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => request_message
                        .content
                        .push(MessageContent::Text(text.clone())),
                    // Thinking output is not included in the summary input.
                    MessageSegment::Thinking { .. } => {}
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            // Drop messages that ended up with no usable content.
            if request_message.content.is_empty() {
                continue;
            }

            request.messages.push(request_message);
        }

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text(added_user_message)],
            cache: false,
        });

        request
    }
1160
1161 fn attached_tracked_files_state(
1162 &self,
1163 messages: &mut Vec<LanguageModelRequestMessage>,
1164 cx: &App,
1165 ) {
1166 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1167
1168 let mut stale_message = String::new();
1169
1170 let action_log = self.action_log.read(cx);
1171
1172 for stale_file in action_log.stale_buffers(cx) {
1173 let Some(file) = stale_file.read(cx).file() else {
1174 continue;
1175 };
1176
1177 if stale_message.is_empty() {
1178 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1179 }
1180
1181 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1182 }
1183
1184 let mut content = Vec::with_capacity(2);
1185
1186 if !stale_message.is_empty() {
1187 content.push(stale_message.into());
1188 }
1189
1190 if !content.is_empty() {
1191 let context_message = LanguageModelRequestMessage {
1192 role: Role::User,
1193 content,
1194 cache: false,
1195 };
1196
1197 messages.push(context_message);
1198 }
1199 }
1200
    /// Starts streaming a completion for `request` from `model`.
    ///
    /// Spawns a task that applies each streamed event to the thread's
    /// messages (text, thinking, tool uses, usage updates), emits the
    /// corresponding [`ThreadEvent`]s, handles errors, and reports telemetry.
    /// The task is registered in `pending_completions` so it can be canceled
    /// by `cancel_last_completion`.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // When a request callback is installed, capture the request plus every
        // response event so the callback can inspect the full exchange.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot usage before streaming so the telemetry below can
            // report the delta for this completion alone.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            // Inner async block so `?` failures funnel into the shared
            // error handling after `stream_completion.await`.
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                while let Some(event) = events.next().await {
                    // Record the raw event (or its error text) for the
                    // request callback, if one is active.
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative within a
                                // request, so only the delta since the last
                                // update is added to the thread-wide total.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking {
                                                text: chunk.to_string(),
                                                signature,
                                            }],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses hang off the latest assistant
                                // message; create one if none exists yet.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only partially-streamed inputs trigger a
                                // `StreamedToolUse` event below.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    // Yield so other tasks can run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known provider errors to their dedicated
                            // ThreadError variants; anything else becomes a
                            // generic message built from the error chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        // Remember the overflow so token-usage
                                        // reporting can reflect it (see
                                        // `total_token_usage`).
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1469
1470 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1471 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1472 return;
1473 };
1474
1475 if !model.provider.is_authenticated(cx) {
1476 return;
1477 }
1478
1479 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1480 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1481 If the conversation is about a specific subject, include it in the title. \
1482 Be descriptive. DO NOT speak in the first person.";
1483
1484 let request = self.to_summarize_request(added_user_message.into());
1485
1486 self.pending_summary = cx.spawn(async move |this, cx| {
1487 async move {
1488 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1489 let (mut messages, usage) = stream.await?;
1490
1491 if let Some(usage) = usage {
1492 this.update(cx, |_thread, cx| {
1493 cx.emit(ThreadEvent::UsageUpdated(usage));
1494 })
1495 .ok();
1496 }
1497
1498 let mut new_summary = String::new();
1499 while let Some(message) = messages.stream.next().await {
1500 let text = message?;
1501 let mut lines = text.lines();
1502 new_summary.extend(lines.next());
1503
1504 // Stop if the LLM generated multiple lines.
1505 if lines.next().is_some() {
1506 break;
1507 }
1508 }
1509
1510 this.update(cx, |this, cx| {
1511 if !new_summary.is_empty() {
1512 this.summary = Some(new_summary.into());
1513 }
1514
1515 cx.emit(ThreadEvent::SummaryGenerated);
1516 })?;
1517
1518 anyhow::Ok(())
1519 }
1520 .log_err()
1521 .await
1522 });
1523 }
1524
1525 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1526 let last_message_id = self.messages.last().map(|message| message.id)?;
1527
1528 match &self.detailed_summary_state {
1529 DetailedSummaryState::Generating { message_id, .. }
1530 | DetailedSummaryState::Generated { message_id, .. }
1531 if *message_id == last_message_id =>
1532 {
1533 // Already up-to-date
1534 return None;
1535 }
1536 _ => {}
1537 }
1538
1539 let ConfiguredModel { model, provider } =
1540 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1541
1542 if !provider.is_authenticated(cx) {
1543 return None;
1544 }
1545
1546 let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1547 1. A brief overview of what was discussed\n\
1548 2. Key facts or information discovered\n\
1549 3. Outcomes or conclusions reached\n\
1550 4. Any action items or next steps if any\n\
1551 Format it in Markdown with headings and bullet points.";
1552
1553 let request = self.to_summarize_request(added_user_message.into());
1554
1555 let task = cx.spawn(async move |thread, cx| {
1556 let stream = model.stream_completion_text(request, &cx);
1557 let Some(mut messages) = stream.await.log_err() else {
1558 thread
1559 .update(cx, |this, _cx| {
1560 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1561 })
1562 .log_err();
1563
1564 return;
1565 };
1566
1567 let mut new_detailed_summary = String::new();
1568
1569 while let Some(chunk) = messages.stream.next().await {
1570 if let Some(chunk) = chunk.log_err() {
1571 new_detailed_summary.push_str(&chunk);
1572 }
1573 }
1574
1575 thread
1576 .update(cx, |this, _cx| {
1577 this.detailed_summary_state = DetailedSummaryState::Generated {
1578 text: new_detailed_summary.into(),
1579 message_id: last_message_id,
1580 };
1581 })
1582 .log_err();
1583 });
1584
1585 self.detailed_summary_state = DetailedSummaryState::Generating {
1586 message_id: last_message_id,
1587 };
1588
1589 Some(task)
1590 }
1591
    /// Returns whether a detailed summary task (spawned by
    /// `generate_detailed_summary`) is currently in flight.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1598
1599 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1600 self.auto_capture_telemetry(cx);
1601 let request = self.to_completion_request(cx);
1602 let messages = Arc::new(request.messages);
1603 let pending_tool_uses = self
1604 .tool_use
1605 .pending_tool_uses()
1606 .into_iter()
1607 .filter(|tool_use| tool_use.status.is_idle())
1608 .cloned()
1609 .collect::<Vec<_>>();
1610
1611 for tool_use in pending_tool_uses.iter() {
1612 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1613 if tool.needs_confirmation(&tool_use.input, cx)
1614 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1615 {
1616 self.tool_use.confirm_tool_use(
1617 tool_use.id.clone(),
1618 tool_use.ui_text.clone(),
1619 tool_use.input.clone(),
1620 messages.clone(),
1621 tool,
1622 );
1623 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1624 } else {
1625 self.run_tool(
1626 tool_use.id.clone(),
1627 tool_use.ui_text.clone(),
1628 tool_use.input.clone(),
1629 &messages,
1630 tool,
1631 cx,
1632 );
1633 }
1634 }
1635 }
1636
1637 pending_tool_uses
1638 }
1639
1640 pub fn run_tool(
1641 &mut self,
1642 tool_use_id: LanguageModelToolUseId,
1643 ui_text: impl Into<SharedString>,
1644 input: serde_json::Value,
1645 messages: &[LanguageModelRequestMessage],
1646 tool: Arc<dyn Tool>,
1647 cx: &mut Context<Thread>,
1648 ) {
1649 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1650 self.tool_use
1651 .run_pending_tool(tool_use_id, ui_text.into(), task);
1652 }
1653
1654 fn spawn_tool_use(
1655 &mut self,
1656 tool_use_id: LanguageModelToolUseId,
1657 messages: &[LanguageModelRequestMessage],
1658 input: serde_json::Value,
1659 tool: Arc<dyn Tool>,
1660 cx: &mut Context<Thread>,
1661 ) -> Task<()> {
1662 let tool_name: Arc<str> = tool.name().into();
1663
1664 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1665 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1666 } else {
1667 tool.run(
1668 input,
1669 messages,
1670 self.project.clone(),
1671 self.action_log.clone(),
1672 cx,
1673 )
1674 };
1675
1676 // Store the card separately if it exists
1677 if let Some(card) = tool_result.card.clone() {
1678 self.tool_use
1679 .insert_tool_result_card(tool_use_id.clone(), card);
1680 }
1681
1682 cx.spawn({
1683 async move |thread: WeakEntity<Thread>, cx| {
1684 let output = tool_result.output.await;
1685
1686 thread
1687 .update(cx, |thread, cx| {
1688 let pending_tool_use = thread.tool_use.insert_tool_output(
1689 tool_use_id.clone(),
1690 tool_name,
1691 output,
1692 cx,
1693 );
1694 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1695 })
1696 .ok();
1697 }
1698 })
1699 }
1700
1701 fn tool_finished(
1702 &mut self,
1703 tool_use_id: LanguageModelToolUseId,
1704 pending_tool_use: Option<PendingToolUse>,
1705 canceled: bool,
1706 cx: &mut Context<Self>,
1707 ) {
1708 if self.all_tools_finished() {
1709 let model_registry = LanguageModelRegistry::read_global(cx);
1710 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1711 self.attach_tool_results(cx);
1712 if !canceled {
1713 self.send_to_model(model, cx);
1714 }
1715 }
1716 }
1717
1718 cx.emit(ThreadEvent::ToolFinished {
1719 tool_use_id,
1720 pending_tool_use,
1721 });
1722 }
1723
    /// Inserts an empty user message to be populated with tool results upon
    /// send.
    ///
    /// Tool results are assumed to be waiting on the next message id, so they
    /// will populate this empty message before sending to the model. Would
    /// prefer this to be more straightforward.
    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        self.insert_message(Role::User, vec![], cx);
        self.auto_capture_telemetry(cx);
    }
1731
1732 /// Cancels the last pending completion, if there are any pending.
1733 ///
1734 /// Returns whether a completion was canceled.
1735 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1736 let canceled = if self.pending_completions.pop().is_some() {
1737 true
1738 } else {
1739 let mut canceled = false;
1740 for pending_tool_use in self.tool_use.cancel_pending() {
1741 canceled = true;
1742 self.tool_finished(
1743 pending_tool_use.id.clone(),
1744 Some(pending_tool_use),
1745 true,
1746 cx,
1747 );
1748 }
1749 canceled
1750 };
1751 self.finalize_pending_checkpoint(cx);
1752 canceled
1753 }
1754
    /// The thread-level feedback the user gave, if any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }

    /// The feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1762
1763 pub fn report_message_feedback(
1764 &mut self,
1765 message_id: MessageId,
1766 feedback: ThreadFeedback,
1767 cx: &mut Context<Self>,
1768 ) -> Task<Result<()>> {
1769 if self.message_feedback.get(&message_id) == Some(&feedback) {
1770 return Task::ready(Ok(()));
1771 }
1772
1773 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1774 let serialized_thread = self.serialize(cx);
1775 let thread_id = self.id().clone();
1776 let client = self.project.read(cx).client();
1777
1778 let enabled_tool_names: Vec<String> = self
1779 .tools()
1780 .read(cx)
1781 .enabled_tools(cx)
1782 .iter()
1783 .map(|tool| tool.name().to_string())
1784 .collect();
1785
1786 self.message_feedback.insert(message_id, feedback);
1787
1788 cx.notify();
1789
1790 let message_content = self
1791 .message(message_id)
1792 .map(|msg| msg.to_string())
1793 .unwrap_or_default();
1794
1795 cx.background_spawn(async move {
1796 let final_project_snapshot = final_project_snapshot.await;
1797 let serialized_thread = serialized_thread.await?;
1798 let thread_data =
1799 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1800
1801 let rating = match feedback {
1802 ThreadFeedback::Positive => "positive",
1803 ThreadFeedback::Negative => "negative",
1804 };
1805 telemetry::event!(
1806 "Assistant Thread Rated",
1807 rating,
1808 thread_id,
1809 enabled_tool_names,
1810 message_id = message_id.0,
1811 message_content,
1812 thread_data,
1813 final_project_snapshot
1814 );
1815 client.telemetry().flush_events();
1816
1817 Ok(())
1818 })
1819 }
1820
1821 pub fn report_feedback(
1822 &mut self,
1823 feedback: ThreadFeedback,
1824 cx: &mut Context<Self>,
1825 ) -> Task<Result<()>> {
1826 let last_assistant_message_id = self
1827 .messages
1828 .iter()
1829 .rev()
1830 .find(|msg| msg.role == Role::Assistant)
1831 .map(|msg| msg.id);
1832
1833 if let Some(message_id) = last_assistant_message_id {
1834 self.report_message_feedback(message_id, feedback, cx)
1835 } else {
1836 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1837 let serialized_thread = self.serialize(cx);
1838 let thread_id = self.id().clone();
1839 let client = self.project.read(cx).client();
1840 self.feedback = Some(feedback);
1841 cx.notify();
1842
1843 cx.background_spawn(async move {
1844 let final_project_snapshot = final_project_snapshot.await;
1845 let serialized_thread = serialized_thread.await?;
1846 let thread_data = serde_json::to_value(serialized_thread)
1847 .unwrap_or_else(|_| serde_json::Value::Null);
1848
1849 let rating = match feedback {
1850 ThreadFeedback::Positive => "positive",
1851 ThreadFeedback::Negative => "negative",
1852 };
1853 telemetry::event!(
1854 "Assistant Thread Rated",
1855 rating,
1856 thread_id,
1857 thread_data,
1858 final_project_snapshot
1859 );
1860 client.telemetry().flush_events();
1861
1862 Ok(())
1863 })
1864 }
1865 }
1866
1867 /// Create a snapshot of the current project state including git information and unsaved buffers.
1868 fn project_snapshot(
1869 project: Entity<Project>,
1870 cx: &mut Context<Self>,
1871 ) -> Task<Arc<ProjectSnapshot>> {
1872 let git_store = project.read(cx).git_store().clone();
1873 let worktree_snapshots: Vec<_> = project
1874 .read(cx)
1875 .visible_worktrees(cx)
1876 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1877 .collect();
1878
1879 cx.spawn(async move |_, cx| {
1880 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1881
1882 let mut unsaved_buffers = Vec::new();
1883 cx.update(|app_cx| {
1884 let buffer_store = project.read(app_cx).buffer_store();
1885 for buffer_handle in buffer_store.read(app_cx).buffers() {
1886 let buffer = buffer_handle.read(app_cx);
1887 if buffer.is_dirty() {
1888 if let Some(file) = buffer.file() {
1889 let path = file.path().to_string_lossy().to_string();
1890 unsaved_buffers.push(path);
1891 }
1892 }
1893 }
1894 })
1895 .ok();
1896
1897 Arc::new(ProjectSnapshot {
1898 worktree_snapshots,
1899 unsaved_buffer_paths: unsaved_buffers,
1900 timestamp: Utc::now(),
1901 })
1902 })
1903 }
1904
1905 fn worktree_snapshot(
1906 worktree: Entity<project::Worktree>,
1907 git_store: Entity<GitStore>,
1908 cx: &App,
1909 ) -> Task<WorktreeSnapshot> {
1910 cx.spawn(async move |cx| {
1911 // Get worktree path and snapshot
1912 let worktree_info = cx.update(|app_cx| {
1913 let worktree = worktree.read(app_cx);
1914 let path = worktree.abs_path().to_string_lossy().to_string();
1915 let snapshot = worktree.snapshot();
1916 (path, snapshot)
1917 });
1918
1919 let Ok((worktree_path, _snapshot)) = worktree_info else {
1920 return WorktreeSnapshot {
1921 worktree_path: String::new(),
1922 git_state: None,
1923 };
1924 };
1925
1926 let git_state = git_store
1927 .update(cx, |git_store, cx| {
1928 git_store
1929 .repositories()
1930 .values()
1931 .find(|repo| {
1932 repo.read(cx)
1933 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1934 .is_some()
1935 })
1936 .cloned()
1937 })
1938 .ok()
1939 .flatten()
1940 .map(|repo| {
1941 repo.update(cx, |repo, _| {
1942 let current_branch =
1943 repo.branch.as_ref().map(|branch| branch.name.to_string());
1944 repo.send_job(None, |state, _| async move {
1945 let RepositoryState::Local { backend, .. } = state else {
1946 return GitState {
1947 remote_url: None,
1948 head_sha: None,
1949 current_branch,
1950 diff: None,
1951 };
1952 };
1953
1954 let remote_url = backend.remote_url("origin");
1955 let head_sha = backend.head_sha();
1956 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1957
1958 GitState {
1959 remote_url,
1960 head_sha,
1961 current_branch,
1962 diff,
1963 }
1964 })
1965 })
1966 });
1967
1968 let git_state = match git_state {
1969 Some(git_state) => match git_state.ok() {
1970 Some(git_state) => git_state.await.ok(),
1971 None => None,
1972 },
1973 None => None,
1974 };
1975
1976 WorktreeSnapshot {
1977 worktree_path,
1978 git_state,
1979 }
1980 })
1981 }
1982
1983 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1984 let mut markdown = Vec::new();
1985
1986 if let Some(summary) = self.summary() {
1987 writeln!(markdown, "# {summary}\n")?;
1988 };
1989
1990 for message in self.messages() {
1991 writeln!(
1992 markdown,
1993 "## {role}\n",
1994 role = match message.role {
1995 Role::User => "User",
1996 Role::Assistant => "Assistant",
1997 Role::System => "System",
1998 }
1999 )?;
2000
2001 if !message.context.is_empty() {
2002 writeln!(markdown, "{}", message.context)?;
2003 }
2004
2005 for segment in &message.segments {
2006 match segment {
2007 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2008 MessageSegment::Thinking { text, .. } => {
2009 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2010 }
2011 MessageSegment::RedactedThinking(_) => {}
2012 }
2013 }
2014
2015 for tool_use in self.tool_uses_for_message(message.id, cx) {
2016 writeln!(
2017 markdown,
2018 "**Use Tool: {} ({})**",
2019 tool_use.name, tool_use.id
2020 )?;
2021 writeln!(markdown, "```json")?;
2022 writeln!(
2023 markdown,
2024 "{}",
2025 serde_json::to_string_pretty(&tool_use.input)?
2026 )?;
2027 writeln!(markdown, "```")?;
2028 }
2029
2030 for tool_result in self.tool_results_for_message(message.id) {
2031 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
2032 if tool_result.is_error {
2033 write!(markdown, " (Error)")?;
2034 }
2035
2036 writeln!(markdown, "**\n")?;
2037 writeln!(markdown, "{}", tool_result.content)?;
2038 }
2039 }
2040
2041 Ok(String::from_utf8_lossy(&markdown).to_string())
2042 }
2043
    /// Delegates to the action log to keep (accept) edits within
    /// `buffer_range` of `buffer`.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }

    /// Delegates to the action log to keep (accept) all tracked edits.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }

    /// Delegates to the action log to reject edits within the given ranges of
    /// `buffer`; the returned task resolves when the rejection completes.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }

    /// The log tracking buffer changes made on this thread's behalf.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2078
2079 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2080 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2081 return;
2082 }
2083
2084 let now = Instant::now();
2085 if let Some(last) = self.last_auto_capture_at {
2086 if now.duration_since(last).as_secs() < 10 {
2087 return;
2088 }
2089 }
2090
2091 self.last_auto_capture_at = Some(now);
2092
2093 let thread_id = self.id().clone();
2094 let github_login = self
2095 .project
2096 .read(cx)
2097 .user_store()
2098 .read(cx)
2099 .current_user()
2100 .map(|user| user.github_login.clone());
2101 let client = self.project.read(cx).client().clone();
2102 let serialize_task = self.serialize(cx);
2103
2104 cx.background_executor()
2105 .spawn(async move {
2106 if let Ok(serialized_thread) = serialize_task.await {
2107 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2108 telemetry::event!(
2109 "Agent Thread Auto-Captured",
2110 thread_id = thread_id.to_string(),
2111 thread_data = thread_data,
2112 auto_capture_reason = "tracked_user",
2113 github_login = github_login
2114 );
2115
2116 client.telemetry().flush_events();
2117 }
2118 }
2119 })
2120 .detach();
2121 }
2122
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2126
2127 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2128 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2129 return TotalTokenUsage::default();
2130 };
2131
2132 let max = model.model.max_token_count();
2133
2134 let index = self
2135 .messages
2136 .iter()
2137 .position(|msg| msg.id == message_id)
2138 .unwrap_or(0);
2139
2140 if index == 0 {
2141 return TotalTokenUsage { total: 0, max };
2142 }
2143
2144 let token_usage = &self
2145 .request_token_usage
2146 .get(index - 1)
2147 .cloned()
2148 .unwrap_or_default();
2149
2150 TotalTokenUsage {
2151 total: token_usage.total_tokens() as usize,
2152 max,
2153 }
2154 }
2155
2156 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2157 let model_registry = LanguageModelRegistry::read_global(cx);
2158 let Some(model) = model_registry.default_model() else {
2159 return TotalTokenUsage::default();
2160 };
2161
2162 let max = model.model.max_token_count();
2163
2164 if let Some(exceeded_error) = &self.exceeded_window_error {
2165 if model.model.id() == exceeded_error.model_id {
2166 return TotalTokenUsage {
2167 total: exceeded_error.token_count,
2168 max,
2169 };
2170 }
2171 }
2172
2173 let total = self
2174 .token_usage_at_last_message()
2175 .unwrap_or_default()
2176 .total_tokens() as usize;
2177
2178 TotalTokenUsage { total, max }
2179 }
2180
2181 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2182 self.request_token_usage
2183 .get(self.messages.len().saturating_sub(1))
2184 .or_else(|| self.request_token_usage.last())
2185 .cloned()
2186 }
2187
2188 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2189 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2190 self.request_token_usage
2191 .resize(self.messages.len(), placeholder);
2192
2193 if let Some(last) = self.request_token_usage.last_mut() {
2194 *last = token_usage;
2195 }
2196 }
2197
2198 pub fn deny_tool_use(
2199 &mut self,
2200 tool_use_id: LanguageModelToolUseId,
2201 tool_name: Arc<str>,
2202 cx: &mut Context<Self>,
2203 ) {
2204 let err = Err(anyhow::anyhow!(
2205 "Permission to run tool action denied by user"
2206 ));
2207
2208 self.tool_use
2209 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2210 self.tool_finished(tool_use_id.clone(), None, true, cx);
2211 }
2212}
2213
/// Errors surfaced to the user while interacting with a language model,
/// emitted via `ThreadEvent::ShowError`.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider rejected the request pending payment.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured monthly spending cap has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model-request quota has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, presented with a header and the full message chain.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2228
/// Events emitted by a `Thread` as completions stream in, messages change,
/// and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be surfaced to the user.
    ShowError(ThreadError),
    /// The provider reported updated request usage.
    UsageUpdated(RequestUsage),
    /// A streamed completion event was received and applied to the thread.
    StreamedCompletion,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use arrived whose input was still streaming when requested.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The completion finished (or failed) with the given result.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// A message's content was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A thread summary (title) was generated.
    SummaryGenerated,
    /// The thread summary was changed.
    SummaryChanged,
    /// The model stopped with tool uses that are now being executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully, with an error, or canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it may run.
    ToolConfirmationNeeded,
}

impl EventEmitter<ThreadEvent> for Thread {}
2261
/// A streaming completion that is currently in flight for a thread.
struct PendingCompletion {
    // Monotonic id used to remove this entry once the stream finishes.
    id: usize,
    // Popped (and thereby dropped) by `cancel_last_completion` to stop the stream.
    _task: Task<()>,
}
2266
2267#[cfg(test)]
2268mod tests {
2269 use super::*;
2270 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2271 use assistant_settings::AssistantSettings;
2272 use context_server::ContextServerSettings;
2273 use editor::EditorSettings;
2274 use gpui::TestAppContext;
2275 use project::{FakeFs, Project};
2276 use prompt_store::PromptBuilder;
2277 use serde_json::json;
2278 use settings::{Settings, SettingsStore};
2279 use std::sync::Arc;
2280 use theme::ThemeSettings;
2281 use util::path;
2282 use workspace::Workspace;
2283
    /// Verifies that a file attached as context is rendered into the stored
    /// user message and into the final completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // NOTE(review): index 1 is the user message; messages[0] is presumably
        // a system message — confirm against `to_completion_request`.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2349
    /// Verifies deduplication of attached context: a context item already
    /// included in an earlier message is not repeated in later messages.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // The request should contain all 3 messages
        // (plus one more; messages[0] is presumably a system message)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }
2451
    /// Verifies that messages inserted with no attached context end up with an
    /// empty `context` string, both in the message and in the request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2513
    /// Verifies that when a buffer attached as context is edited after being
    /// read, the next completion request ends with a "stale files" notice.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2588
    /// Registers the global settings store and every settings type the
    /// entities under test rely on. Each test calls this first.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The global `SettingsStore` is set before the `register`/`init`
            // calls below — presumably they read it; keep this ordering.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }
2604
    /// Creates a test [`Project`] backed by a [`FakeFs`] populated with the
    /// given `json!` file tree, rooted at `/test`.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }
2614
    /// Builds the standard test fixture for the given project: a workspace
    /// window, a loaded [`ThreadStore`], a fresh [`Thread`], and an empty
    /// [`ContextStore`].
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }
2644
    /// Opens `path` as a buffer in `project` and attaches it to
    /// `context_store`, returning the buffer so tests can edit it afterwards.
    ///
    /// Panics if the path cannot be resolved or the buffer cannot be opened;
    /// returns `Err` only if adding the buffer to the context store fails.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
2668}