1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
42
/// The unique identifier of a [`Thread`].
///
/// Wraps an `Arc<str>` (a UUID string in practice — see [`ThreadId::new`]), so clones are cheap.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
47
48impl ThreadId {
49 pub fn new() -> Self {
50 Self(Uuid::new_v4().to_string().into())
51 }
52}
53
54impl std::fmt::Display for ThreadId {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
60impl From<&str> for ThreadId {
61 fn from(value: &str) -> Self {
62 Self(value.into())
63 }
64}
65
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>); // UUID string (see `PromptId::new`); `Arc<str>` keeps clones cheap.
71
72impl PromptId {
73 pub fn new() -> Self {
74 Self(Uuid::new_v4().to_string().into())
75 }
76}
77
78impl std::fmt::Display for PromptId {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "{}", self.0)
81 }
82}
83
/// The identifier of a [`Message`] within a thread.
///
/// IDs are assigned monotonically (see [`MessageId::post_inc`]) and are stable across
/// serialization, so they can key per-message state such as context and checkpoints.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
86
87impl MessageId {
88 fn post_inc(&mut self) -> Self {
89 Self(post_inc(&mut self.0))
90 }
91}
92
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces: plain text plus (possibly redacted) thinking.
    pub segments: Vec<MessageSegment>,
    /// Pre-rendered context attached to this message (empty when none).
    pub context: String,
    /// Images attached alongside this message (resolved from image context).
    pub images: Vec<LanguageModelImage>,
}
102
103impl Message {
104 /// Returns whether the message contains any meaningful text that should be displayed
105 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
106 pub fn should_display_content(&self) -> bool {
107 self.segments.iter().all(|segment| segment.should_display())
108 }
109
110 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
111 if let Some(MessageSegment::Thinking {
112 text: segment,
113 signature: current_signature,
114 }) = self.segments.last_mut()
115 {
116 if let Some(signature) = signature {
117 *current_signature = Some(signature);
118 }
119 segment.push_str(text);
120 } else {
121 self.segments.push(MessageSegment::Thinking {
122 text: text.to_string(),
123 signature,
124 });
125 }
126 }
127
128 pub fn push_text(&mut self, text: &str) {
129 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
130 segment.push_str(text);
131 } else {
132 self.segments.push(MessageSegment::Text(text.to_string()));
133 }
134 }
135
136 pub fn to_string(&self) -> String {
137 let mut result = String::new();
138
139 if !self.context.is_empty() {
140 result.push_str(&self.context);
141 }
142
143 for segment in &self.segments {
144 match segment {
145 MessageSegment::Text(text) => result.push_str(text),
146 MessageSegment::Thinking { text, .. } => {
147 result.push_str("<think>\n");
148 result.push_str(text);
149 result.push_str("\n</think>");
150 }
151 MessageSegment::RedactedThinking(_) => {}
152 }
153 }
154
155 result
156 }
157}
158
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain message text.
    Text(String),
    /// Model "thinking" output, with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted thinking: opaque bytes, never displayed.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has meaningful, displayable content.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking is opaque data and is never displayed.
    /// (The emptiness checks were inverted before, making empty segments
    /// count as displayable — contradicting `Message::should_display_content`.)
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
178
/// A capture of the project's state (worktrees, git, unsaved buffers) at a point in time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}
185
/// The state of a single worktree within a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // NOTE(review): `None` presumably means the worktree is not a git repository — confirm in `project_snapshot`.
    pub git_state: Option<GitState>,
}
191
/// Git repository state captured for a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Captured diff text, if one was produced.
    pub diff: Option<String>,
}
199
/// Links a git checkpoint to the message it was taken for, so the repository
/// can later be restored to its state as of that message (see `restore_checkpoint`).
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
205
/// User feedback on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
211
/// Tracks the outcome of the most recent checkpoint restore.
pub enum LastRestoreCheckpoint {
    /// The restore is still running.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed with the given error message.
    Error {
        message_id: MessageId,
        error: String,
    },
}
221
222impl LastRestoreCheckpoint {
223 pub fn message_id(&self) -> MessageId {
224 match self {
225 LastRestoreCheckpoint::Pending { message_id } => *message_id,
226 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
227 }
228 }
229}
230
/// Progress of generating a detailed summary of the thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    // NOTE(review): `message_id` appears to mark the last message covered by the
    // summary — confirm at the generation site.
    Generating {
        message_id: MessageId,
    },
    /// The finished summary text.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
243
/// Running token totals for a thread, measured against the model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens the context window allows.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or over it.
    ///
    /// In debug builds the warning threshold can be overridden via the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable; a missing or
    /// unparsable value falls back to the default instead of panicking
    /// (the previous `.parse().unwrap()` would crash on a malformed value).
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // `total >= max` also covers `max == 0`, so the division below never sees a zero divisor.
        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a new usage with `tokens` added to the total; `max` is unchanged.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close a thread's token usage is to the model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
284
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Last time the thread was mutated (see `touch_updated_at`).
    updated_at: DateTime<Utc>,
    /// Short human-readable title; `None` until a summary has been generated.
    summary: Option<SharedString>,
    /// In-flight summary generation task, if any.
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    /// ID that will be assigned to the next inserted message.
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    /// All context attached to this thread, keyed by ID.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Which context IDs were attached alongside each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    /// Restorable git checkpoints, keyed by the message they were taken for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// State of the most recent checkpoint restore, if pending or failed.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken for the in-flight user message; resolved when generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured at thread creation; persisted by `serialize`.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    /// Token usage accumulated across all requests in this thread.
    cumulative_token_usage: TokenUsage,
    /// Set when a request exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    /// Thread-level user feedback, if given.
    feedback: Option<ThreadFeedback>,
    /// Per-message user feedback.
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    // NOTE(review): appears to throttle `auto_capture_telemetry` — confirm at its definition.
    last_auto_capture_at: Option<Instant>,
    /// Optional observer invoked with each request and its raw completion events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
}
319
/// Recorded when a completion request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
327
328impl Thread {
    /// Creates a new, empty thread.
    ///
    /// An initial project snapshot is captured asynchronously and later
    /// persisted when the thread is serialized.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the snapshot in the background; `Shared` lets
                // `serialize` await the same result later.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
373
    /// Reconstructs a thread from its serialized form.
    ///
    /// Message numbering continues after the highest persisted ID, and tool
    /// use state is rebuilt from the serialized messages. Per-message context
    /// and checkpoint associations are not persisted and start out empty.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the last persisted message (or from 0 for an empty thread).
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    context: message.context,
                    // Images are not persisted.
                    images: Vec::new(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
447
    /// Installs an observer that will be called with each completion request
    /// and the raw events received for it (stored until streaming uses it).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
455
    /// The unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if one hasn't been generated yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// The shared project context used to build the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Placeholder title used until a summary has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The generated summary, or [`Self::DEFAULT_SUMMARY`] if none exists yet.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
489
490 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
491 let Some(current_summary) = &self.summary else {
492 // Don't allow setting summary until generated
493 return;
494 };
495
496 let mut new_summary = new_summary.into();
497
498 if new_summary.is_empty() {
499 new_summary = Self::DEFAULT_SUMMARY;
500 }
501
502 if current_summary != &new_summary {
503 self.summary = Some(new_summary);
504 cx.emit(ThreadEvent::SummaryChanged);
505 }
506 }
507
508 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
509 self.latest_detailed_summary()
510 .unwrap_or_else(|| self.text().into())
511 }
512
513 fn latest_detailed_summary(&self) -> Option<SharedString> {
514 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
515 Some(text.clone())
516 } else {
517 None
518 }
519 }
520
    /// Returns the message with the given ID, if any.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    /// Iterates over all messages in insertion order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or any tool use is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// The set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The restorable checkpoint associated with the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
558
    /// Restores the project to `checkpoint`'s git state.
    ///
    /// The restore is recorded as pending while it runs. On success the thread
    /// is truncated back to the checkpoint's message; on failure the error is
    /// kept in `last_restore_checkpoint` so the UI can surface it.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop the checkpoint's message and everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
594
    /// Resolves the checkpoint taken for the last user message once generation
    /// has fully finished.
    ///
    /// Takes a fresh checkpoint and compares it with the pending one: if the
    /// repository is unchanged the pending checkpoint is deleted, otherwise it
    /// is attached to its message so the user can restore it. The comparison
    /// checkpoint is always cleaned up afterwards.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still generating; keep the checkpoint pending for now.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    // Nothing changed since the checkpoint was taken; discard it.
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            // Couldn't take a comparison checkpoint; keep the pending one to be safe.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
645
    /// Records a checkpoint under its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The state of the most recent checkpoint restore, if pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
656
657 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
658 let Some(message_ix) = self
659 .messages
660 .iter()
661 .rposition(|message| message.id == message_id)
662 else {
663 return;
664 };
665 for deleted_message in self.messages.drain(message_ix..) {
666 self.context_by_message.remove(&deleted_message.id);
667 self.checkpoints_by_message.remove(&deleted_message.id);
668 }
669 cx.notify();
670 }
671
672 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
673 self.context_by_message
674 .get(&id)
675 .into_iter()
676 .flat_map(|context| {
677 context
678 .iter()
679 .filter_map(|context_id| self.context.get(&context_id))
680 })
681 }
682
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        // (An empty list also counts as finished: `all` on an empty iterator is true.)
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
692
    /// All tool uses recorded for the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// All tool results recorded for the given message.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The textual output of a completed tool use, if any.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool use's result, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// Whether the given message has any tool results attached.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
716
717 /// Filter out contexts that have already been included in previous messages
718 pub fn filter_new_context<'a>(
719 &self,
720 context: impl Iterator<Item = &'a AssistantContext>,
721 ) -> impl Iterator<Item = &'a AssistantContext> {
722 context.filter(|ctx| self.is_context_new(ctx))
723 }
724
725 fn is_context_new(&self, context: &AssistantContext) -> bool {
726 !self.context.contains_key(&context.id())
727 }
728
729 pub fn insert_user_message(
730 &mut self,
731 text: impl Into<String>,
732 context: Vec<AssistantContext>,
733 git_checkpoint: Option<GitStoreCheckpoint>,
734 cx: &mut Context<Self>,
735 ) -> MessageId {
736 let text = text.into();
737
738 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
739
740 let new_context: Vec<_> = context
741 .into_iter()
742 .filter(|ctx| self.is_context_new(ctx))
743 .collect();
744
745 if !new_context.is_empty() {
746 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
747 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
748 message.context = context_string;
749 }
750 }
751
752 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
753 message.images = new_context
754 .iter()
755 .filter_map(|context| {
756 if let AssistantContext::Image(image_context) = context {
757 image_context.image_task.clone().now_or_never().flatten()
758 } else {
759 None
760 }
761 })
762 .collect::<Vec<_>>();
763 }
764
765 self.action_log.update(cx, |log, cx| {
766 // Track all buffers added as context
767 for ctx in &new_context {
768 match ctx {
769 AssistantContext::File(file_ctx) => {
770 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
771 }
772 AssistantContext::Directory(dir_ctx) => {
773 for context_buffer in &dir_ctx.context_buffers {
774 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
775 }
776 }
777 AssistantContext::Symbol(symbol_ctx) => {
778 log.buffer_added_as_context(
779 symbol_ctx.context_symbol.buffer.clone(),
780 cx,
781 );
782 }
783 AssistantContext::Selection(selection_context) => {
784 log.buffer_added_as_context(
785 selection_context.context_buffer.buffer.clone(),
786 cx,
787 );
788 }
789 AssistantContext::FetchedUrl(_)
790 | AssistantContext::Thread(_)
791 | AssistantContext::Rules(_)
792 | AssistantContext::Image(_) => {}
793 }
794 }
795 });
796 }
797
798 let context_ids = new_context
799 .iter()
800 .map(|context| context.id())
801 .collect::<Vec<_>>();
802 self.context.extend(
803 new_context
804 .into_iter()
805 .map(|context| (context.id(), context)),
806 );
807 self.context_by_message.insert(message_id, context_ids);
808
809 if let Some(git_checkpoint) = git_checkpoint {
810 self.pending_checkpoint = Some(ThreadCheckpoint {
811 message_id,
812 git_checkpoint,
813 });
814 }
815
816 self.auto_capture_telemetry(cx);
817
818 message_id
819 }
820
821 pub fn insert_message(
822 &mut self,
823 role: Role,
824 segments: Vec<MessageSegment>,
825 cx: &mut Context<Self>,
826 ) -> MessageId {
827 let id = self.next_message_id.post_inc();
828 self.messages.push(Message {
829 id,
830 role,
831 segments,
832 context: String::new(),
833 images: Vec::new(),
834 });
835 self.touch_updated_at();
836 cx.emit(ThreadEvent::MessageAdded(id));
837 id
838 }
839
840 pub fn edit_message(
841 &mut self,
842 id: MessageId,
843 new_role: Role,
844 new_segments: Vec<MessageSegment>,
845 cx: &mut Context<Self>,
846 ) -> bool {
847 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
848 return false;
849 };
850 message.role = new_role;
851 message.segments = new_segments;
852 self.touch_updated_at();
853 cx.emit(ThreadEvent::MessageEdited(id));
854 true
855 }
856
857 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
858 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
859 return false;
860 };
861 self.messages.remove(index);
862 self.context_by_message.remove(&id);
863 self.touch_updated_at();
864 cx.emit(ThreadEvent::MessageDeleted(id));
865 true
866 }
867
868 /// Returns the representation of this [`Thread`] in a textual form.
869 ///
870 /// This is the representation we use when attaching a thread as context to another thread.
871 pub fn text(&self) -> String {
872 let mut text = String::new();
873
874 for message in &self.messages {
875 text.push_str(match message.role {
876 language_model::Role::User => "User:",
877 language_model::Role::Assistant => "Assistant:",
878 language_model::Role::System => "System:",
879 });
880 text.push('\n');
881
882 for segment in &message.segments {
883 match segment {
884 MessageSegment::Text(content) => text.push_str(content),
885 MessageSegment::Thinking { text: content, .. } => {
886 text.push_str(&format!("<think>{}</think>", content))
887 }
888 MessageSegment::RedactedThinking(_) => {}
889 }
890 }
891 text.push('\n');
892 }
893
894 text
895 }
896
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot to resolve, then captures the
    /// messages along with their tool uses, tool results, and context.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // The snapshot task is `Shared`, so this may resolve immediately.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
960
961 pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
962 let mut request = self.to_completion_request(cx);
963 if model.supports_tools() {
964 request.tools = {
965 let mut tools = Vec::new();
966 tools.extend(
967 self.tools()
968 .read(cx)
969 .enabled_tools(cx)
970 .into_iter()
971 .filter_map(|tool| {
972 // Skip tools that cannot be supported
973 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
974 Some(LanguageModelRequestTool {
975 name: tool.name(),
976 description: tool.description(),
977 input_schema,
978 })
979 }),
980 );
981
982 tools
983 };
984 }
985
986 self.stream_completion(request, model, cx);
987 }
988
989 pub fn used_tools_since_last_user_message(&self) -> bool {
990 for message in self.messages.iter().rev() {
991 if self.tool_use.message_has_tool_results(message.id) {
992 return true;
993 } else if message.role == Role::User {
994 return false;
995 }
996 }
997
998 false
999 }
1000
    /// Builds a completion request from the whole thread.
    ///
    /// The request contains the system prompt first (when the project context
    /// is ready), then every message with its tool results, attached context,
    /// images, segments, and tool uses — in that order. The last message is
    /// marked cacheable for prompt caching.
    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt is stable across requests, so cache it.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attach any tool results associated with this message before the
            // rest of its content.
            self.tool_use
                .attach_tool_results(message.id, &mut request_message);

            if !message.context.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.context.to_string()));
            }

            if !message.images.is_empty() {
                // Some providers only support image parts after an initial text part
                if request_message.content.is_empty() {
                    request_message
                        .content
                        .push(MessageContent::Text("Images attached by user:".to_string()));
                }

                for image in &message.images {
                    request_message
                        .content
                        .push(MessageContent::Image(image.clone()))
                }
            }

            // Empty segments are skipped so no blank content parts are sent.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1112
    /// Builds a request asking the model to summarize the thread: the
    /// conversation so far (text segments only, skipping messages that carry
    /// tool results), with `added_user_message` appended as the final user turn.
    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Skip tool results during summarization.
            if self.tool_use.message_has_tool_results(message.id) {
                continue;
            }

            // Only text segments are included; thinking is omitted from summaries.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => request_message
                        .content
                        .push(MessageContent::Text(text.clone())),
                    MessageSegment::Thinking { .. } => {}
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            // Don't send messages that ended up with no content at all.
            if request_message.content.is_empty() {
                continue;
            }

            request.messages.push(request_message);
        }

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text(added_user_message)],
            cache: false,
        });

        request
    }
1160
1161 fn attached_tracked_files_state(
1162 &self,
1163 messages: &mut Vec<LanguageModelRequestMessage>,
1164 cx: &App,
1165 ) {
1166 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1167
1168 let mut stale_message = String::new();
1169
1170 let action_log = self.action_log.read(cx);
1171
1172 for stale_file in action_log.stale_buffers(cx) {
1173 let Some(file) = stale_file.read(cx).file() else {
1174 continue;
1175 };
1176
1177 if stale_message.is_empty() {
1178 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1179 }
1180
1181 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1182 }
1183
1184 let mut content = Vec::with_capacity(2);
1185
1186 if !stale_message.is_empty() {
1187 content.push(stale_message.into());
1188 }
1189
1190 if !content.is_empty() {
1191 let context_message = LanguageModelRequestMessage {
1192 role: Role::User,
1193 content,
1194 cache: false,
1195 };
1196
1197 messages.push(context_message);
1198 }
1199 }
1200
1201 pub fn stream_completion(
1202 &mut self,
1203 request: LanguageModelRequest,
1204 model: Arc<dyn LanguageModel>,
1205 cx: &mut Context<Self>,
1206 ) {
1207 let pending_completion_id = post_inc(&mut self.completion_count);
1208 let mut request_callback_parameters = if self.request_callback.is_some() {
1209 Some((request.clone(), Vec::new()))
1210 } else {
1211 None
1212 };
1213 let prompt_id = self.last_prompt_id.clone();
1214 let tool_use_metadata = ToolUseMetadata {
1215 model: model.clone(),
1216 thread_id: self.id.clone(),
1217 prompt_id: prompt_id.clone(),
1218 };
1219
1220 let task = cx.spawn(async move |thread, cx| {
1221 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1222 let initial_token_usage =
1223 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1224 let stream_completion = async {
1225 let (mut events, usage) = stream_completion_future.await?;
1226
1227 let mut stop_reason = StopReason::EndTurn;
1228 let mut current_token_usage = TokenUsage::default();
1229
1230 if let Some(usage) = usage {
1231 thread
1232 .update(cx, |_thread, cx| {
1233 cx.emit(ThreadEvent::UsageUpdated(usage));
1234 })
1235 .ok();
1236 }
1237
1238 while let Some(event) = events.next().await {
1239 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1240 response_events
1241 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1242 }
1243
1244 let event = event?;
1245
1246 thread.update(cx, |thread, cx| {
1247 match event {
1248 LanguageModelCompletionEvent::StartMessage { .. } => {
1249 thread.insert_message(
1250 Role::Assistant,
1251 vec![MessageSegment::Text(String::new())],
1252 cx,
1253 );
1254 }
1255 LanguageModelCompletionEvent::Stop(reason) => {
1256 stop_reason = reason;
1257 }
1258 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1259 thread.update_token_usage_at_last_message(token_usage);
1260 thread.cumulative_token_usage = thread.cumulative_token_usage
1261 + token_usage
1262 - current_token_usage;
1263 current_token_usage = token_usage;
1264 }
1265 LanguageModelCompletionEvent::Text(chunk) => {
1266 cx.emit(ThreadEvent::ReceivedTextChunk);
1267 if let Some(last_message) = thread.messages.last_mut() {
1268 if last_message.role == Role::Assistant {
1269 last_message.push_text(&chunk);
1270 cx.emit(ThreadEvent::StreamedAssistantText(
1271 last_message.id,
1272 chunk,
1273 ));
1274 } else {
1275 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1276 // of a new Assistant response.
1277 //
1278 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1279 // will result in duplicating the text of the chunk in the rendered Markdown.
1280 thread.insert_message(
1281 Role::Assistant,
1282 vec![MessageSegment::Text(chunk.to_string())],
1283 cx,
1284 );
1285 };
1286 }
1287 }
1288 LanguageModelCompletionEvent::Thinking {
1289 text: chunk,
1290 signature,
1291 } => {
1292 if let Some(last_message) = thread.messages.last_mut() {
1293 if last_message.role == Role::Assistant {
1294 last_message.push_thinking(&chunk, signature);
1295 cx.emit(ThreadEvent::StreamedAssistantThinking(
1296 last_message.id,
1297 chunk,
1298 ));
1299 } else {
1300 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1301 // of a new Assistant response.
1302 //
1303 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1304 // will result in duplicating the text of the chunk in the rendered Markdown.
1305 thread.insert_message(
1306 Role::Assistant,
1307 vec![MessageSegment::Thinking {
1308 text: chunk.to_string(),
1309 signature,
1310 }],
1311 cx,
1312 );
1313 };
1314 }
1315 }
1316 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1317 let last_assistant_message_id = thread
1318 .messages
1319 .iter_mut()
1320 .rfind(|message| message.role == Role::Assistant)
1321 .map(|message| message.id)
1322 .unwrap_or_else(|| {
1323 thread.insert_message(Role::Assistant, vec![], cx)
1324 });
1325
1326 let tool_use_id = tool_use.id.clone();
1327 let streamed_input = if tool_use.is_input_complete {
1328 None
1329 } else {
1330 Some((&tool_use.input).clone())
1331 };
1332
1333 let ui_text = thread.tool_use.request_tool_use(
1334 last_assistant_message_id,
1335 tool_use,
1336 tool_use_metadata.clone(),
1337 cx,
1338 );
1339
1340 if let Some(input) = streamed_input {
1341 cx.emit(ThreadEvent::StreamedToolUse {
1342 tool_use_id,
1343 ui_text,
1344 input,
1345 });
1346 }
1347 }
1348 }
1349
1350 thread.touch_updated_at();
1351 cx.emit(ThreadEvent::StreamedCompletion);
1352 cx.notify();
1353
1354 thread.auto_capture_telemetry(cx);
1355 })?;
1356
1357 smol::future::yield_now().await;
1358 }
1359
1360 thread.update(cx, |thread, cx| {
1361 thread
1362 .pending_completions
1363 .retain(|completion| completion.id != pending_completion_id);
1364
1365 // If there is a response without tool use, summarize the message. Otherwise,
1366 // allow two tool uses before summarizing.
1367 if thread.summary.is_none()
1368 && thread.messages.len() >= 2
1369 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1370 {
1371 thread.summarize(cx);
1372 }
1373 })?;
1374
1375 anyhow::Ok(stop_reason)
1376 };
1377
1378 let result = stream_completion.await;
1379
1380 thread
1381 .update(cx, |thread, cx| {
1382 thread.finalize_pending_checkpoint(cx);
1383 match result.as_ref() {
1384 Ok(stop_reason) => match stop_reason {
1385 StopReason::ToolUse => {
1386 let tool_uses = thread.use_pending_tools(cx);
1387 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1388 }
1389 StopReason::EndTurn => {}
1390 StopReason::MaxTokens => {}
1391 },
1392 Err(error) => {
1393 if error.is::<PaymentRequiredError>() {
1394 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1395 } else if error.is::<MaxMonthlySpendReachedError>() {
1396 cx.emit(ThreadEvent::ShowError(
1397 ThreadError::MaxMonthlySpendReached,
1398 ));
1399 } else if let Some(error) =
1400 error.downcast_ref::<ModelRequestLimitReachedError>()
1401 {
1402 cx.emit(ThreadEvent::ShowError(
1403 ThreadError::ModelRequestLimitReached { plan: error.plan },
1404 ));
1405 } else if let Some(known_error) =
1406 error.downcast_ref::<LanguageModelKnownError>()
1407 {
1408 match known_error {
1409 LanguageModelKnownError::ContextWindowLimitExceeded {
1410 tokens,
1411 } => {
1412 thread.exceeded_window_error = Some(ExceededWindowError {
1413 model_id: model.id(),
1414 token_count: *tokens,
1415 });
1416 cx.notify();
1417 }
1418 }
1419 } else {
1420 let error_message = error
1421 .chain()
1422 .map(|err| err.to_string())
1423 .collect::<Vec<_>>()
1424 .join("\n");
1425 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1426 header: "Error interacting with language model".into(),
1427 message: SharedString::from(error_message.clone()),
1428 }));
1429 }
1430
1431 thread.cancel_last_completion(cx);
1432 }
1433 }
1434 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1435
1436 if let Some((request_callback, (request, response_events))) = thread
1437 .request_callback
1438 .as_mut()
1439 .zip(request_callback_parameters.as_ref())
1440 {
1441 request_callback(request, response_events);
1442 }
1443
1444 thread.auto_capture_telemetry(cx);
1445
1446 if let Ok(initial_usage) = initial_token_usage {
1447 let usage = thread.cumulative_token_usage - initial_usage;
1448
1449 telemetry::event!(
1450 "Assistant Thread Completion",
1451 thread_id = thread.id().to_string(),
1452 prompt_id = prompt_id,
1453 model = model.telemetry_id(),
1454 model_provider = model.provider_id().to_string(),
1455 input_tokens = usage.input_tokens,
1456 output_tokens = usage.output_tokens,
1457 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1458 cache_read_input_tokens = usage.cache_read_input_tokens,
1459 );
1460 }
1461 })
1462 .ok();
1463 });
1464
1465 self.pending_completions.push(PendingCompletion {
1466 id: pending_completion_id,
1467 _task: task,
1468 });
1469 }
1470
1471 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1472 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1473 return;
1474 };
1475
1476 if !model.provider.is_authenticated(cx) {
1477 return;
1478 }
1479
1480 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1481 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1482 If the conversation is about a specific subject, include it in the title. \
1483 Be descriptive. DO NOT speak in the first person.";
1484
1485 let request = self.to_summarize_request(added_user_message.into());
1486
1487 self.pending_summary = cx.spawn(async move |this, cx| {
1488 async move {
1489 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1490 let (mut messages, usage) = stream.await?;
1491
1492 if let Some(usage) = usage {
1493 this.update(cx, |_thread, cx| {
1494 cx.emit(ThreadEvent::UsageUpdated(usage));
1495 })
1496 .ok();
1497 }
1498
1499 let mut new_summary = String::new();
1500 while let Some(message) = messages.stream.next().await {
1501 let text = message?;
1502 let mut lines = text.lines();
1503 new_summary.extend(lines.next());
1504
1505 // Stop if the LLM generated multiple lines.
1506 if lines.next().is_some() {
1507 break;
1508 }
1509 }
1510
1511 this.update(cx, |this, cx| {
1512 if !new_summary.is_empty() {
1513 this.summary = Some(new_summary.into());
1514 }
1515
1516 cx.emit(ThreadEvent::SummaryGenerated);
1517 })?;
1518
1519 anyhow::Ok(())
1520 }
1521 .log_err()
1522 .await
1523 });
1524 }
1525
1526 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1527 let last_message_id = self.messages.last().map(|message| message.id)?;
1528
1529 match &self.detailed_summary_state {
1530 DetailedSummaryState::Generating { message_id, .. }
1531 | DetailedSummaryState::Generated { message_id, .. }
1532 if *message_id == last_message_id =>
1533 {
1534 // Already up-to-date
1535 return None;
1536 }
1537 _ => {}
1538 }
1539
1540 let ConfiguredModel { model, provider } =
1541 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1542
1543 if !provider.is_authenticated(cx) {
1544 return None;
1545 }
1546
1547 let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1548 1. A brief overview of what was discussed\n\
1549 2. Key facts or information discovered\n\
1550 3. Outcomes or conclusions reached\n\
1551 4. Any action items or next steps if any\n\
1552 Format it in Markdown with headings and bullet points.";
1553
1554 let request = self.to_summarize_request(added_user_message.into());
1555
1556 let task = cx.spawn(async move |thread, cx| {
1557 let stream = model.stream_completion_text(request, &cx);
1558 let Some(mut messages) = stream.await.log_err() else {
1559 thread
1560 .update(cx, |this, _cx| {
1561 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1562 })
1563 .log_err();
1564
1565 return;
1566 };
1567
1568 let mut new_detailed_summary = String::new();
1569
1570 while let Some(chunk) = messages.stream.next().await {
1571 if let Some(chunk) = chunk.log_err() {
1572 new_detailed_summary.push_str(&chunk);
1573 }
1574 }
1575
1576 thread
1577 .update(cx, |this, _cx| {
1578 this.detailed_summary_state = DetailedSummaryState::Generated {
1579 text: new_detailed_summary.into(),
1580 message_id: last_message_id,
1581 };
1582 })
1583 .log_err();
1584 });
1585
1586 self.detailed_summary_state = DetailedSummaryState::Generating {
1587 message_id: last_message_id,
1588 };
1589
1590 Some(task)
1591 }
1592
    /// Whether a detailed summary is currently being generated for this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1599
1600 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1601 self.auto_capture_telemetry(cx);
1602 let request = self.to_completion_request(cx);
1603 let messages = Arc::new(request.messages);
1604 let pending_tool_uses = self
1605 .tool_use
1606 .pending_tool_uses()
1607 .into_iter()
1608 .filter(|tool_use| tool_use.status.is_idle())
1609 .cloned()
1610 .collect::<Vec<_>>();
1611
1612 for tool_use in pending_tool_uses.iter() {
1613 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1614 if tool.needs_confirmation(&tool_use.input, cx)
1615 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1616 {
1617 self.tool_use.confirm_tool_use(
1618 tool_use.id.clone(),
1619 tool_use.ui_text.clone(),
1620 tool_use.input.clone(),
1621 messages.clone(),
1622 tool,
1623 );
1624 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1625 } else {
1626 self.run_tool(
1627 tool_use.id.clone(),
1628 tool_use.ui_text.clone(),
1629 tool_use.input.clone(),
1630 &messages,
1631 tool,
1632 cx,
1633 );
1634 }
1635 }
1636 }
1637
1638 pending_tool_uses
1639 }
1640
1641 pub fn run_tool(
1642 &mut self,
1643 tool_use_id: LanguageModelToolUseId,
1644 ui_text: impl Into<SharedString>,
1645 input: serde_json::Value,
1646 messages: &[LanguageModelRequestMessage],
1647 tool: Arc<dyn Tool>,
1648 cx: &mut Context<Thread>,
1649 ) {
1650 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1651 self.tool_use
1652 .run_pending_tool(tool_use_id, ui_text.into(), task);
1653 }
1654
1655 fn spawn_tool_use(
1656 &mut self,
1657 tool_use_id: LanguageModelToolUseId,
1658 messages: &[LanguageModelRequestMessage],
1659 input: serde_json::Value,
1660 tool: Arc<dyn Tool>,
1661 cx: &mut Context<Thread>,
1662 ) -> Task<()> {
1663 let tool_name: Arc<str> = tool.name().into();
1664
1665 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1666 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1667 } else {
1668 tool.run(
1669 input,
1670 messages,
1671 self.project.clone(),
1672 self.action_log.clone(),
1673 cx,
1674 )
1675 };
1676
1677 // Store the card separately if it exists
1678 if let Some(card) = tool_result.card.clone() {
1679 self.tool_use
1680 .insert_tool_result_card(tool_use_id.clone(), card);
1681 }
1682
1683 cx.spawn({
1684 async move |thread: WeakEntity<Thread>, cx| {
1685 let output = tool_result.output.await;
1686
1687 thread
1688 .update(cx, |thread, cx| {
1689 let pending_tool_use = thread.tool_use.insert_tool_output(
1690 tool_use_id.clone(),
1691 tool_name,
1692 output,
1693 cx,
1694 );
1695 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1696 })
1697 .ok();
1698 }
1699 })
1700 }
1701
1702 fn tool_finished(
1703 &mut self,
1704 tool_use_id: LanguageModelToolUseId,
1705 pending_tool_use: Option<PendingToolUse>,
1706 canceled: bool,
1707 cx: &mut Context<Self>,
1708 ) {
1709 if self.all_tools_finished() {
1710 let model_registry = LanguageModelRegistry::read_global(cx);
1711 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1712 self.attach_tool_results(cx);
1713 if !canceled {
1714 self.send_to_model(model, cx);
1715 }
1716 }
1717 }
1718
1719 cx.emit(ThreadEvent::ToolFinished {
1720 tool_use_id,
1721 pending_tool_use,
1722 });
1723 }
1724
    /// Insert an empty message to be populated with tool results upon send.
    ///
    /// Also triggers an auto-capture telemetry attempt, since the thread
    /// state just changed.
    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Tool results are assumed to be waiting on the next message id, so they will populate
        // this empty message before sending to model. Would prefer this to be more straightforward.
        self.insert_message(Role::User, vec![], cx);
        self.auto_capture_telemetry(cx);
    }
1732
1733 /// Cancels the last pending completion, if there are any pending.
1734 ///
1735 /// Returns whether a completion was canceled.
1736 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1737 let canceled = if self.pending_completions.pop().is_some() {
1738 true
1739 } else {
1740 let mut canceled = false;
1741 for pending_tool_use in self.tool_use.cancel_pending() {
1742 canceled = true;
1743 self.tool_finished(
1744 pending_tool_use.id.clone(),
1745 Some(pending_tool_use),
1746 true,
1747 cx,
1748 );
1749 }
1750 canceled
1751 };
1752 self.finalize_pending_checkpoint(cx);
1753 canceled
1754 }
1755
    /// Thread-level feedback, set by `report_feedback` when no assistant
    /// message exists to attach feedback to.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1759
    /// Feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1763
    /// Records `feedback` for `message_id` and reports it via telemetry,
    /// including a serialized copy of the thread and a project snapshot.
    /// Re-submitting the same feedback for the same message is a no-op.
    ///
    /// NOTE: the local variable names below (`rating`, `thread_id`, etc.) are
    /// used as telemetry event keys by `telemetry::event!` — don't rename them.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Skip duplicate submissions of the same rating.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
1821
    /// Reports feedback for the thread: attaches it to the most recent
    /// assistant message when one exists, otherwise stores it as thread-level
    /// feedback and emits a telemetry event without message details.
    ///
    /// NOTE: the local variable names below (`rating`, `thread_id`, etc.) are
    /// used as telemetry event keys by `telemetry::event!` — don't rename them.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure degrades to a null payload rather
                // than dropping the event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
1867
1868 /// Create a snapshot of the current project state including git information and unsaved buffers.
1869 fn project_snapshot(
1870 project: Entity<Project>,
1871 cx: &mut Context<Self>,
1872 ) -> Task<Arc<ProjectSnapshot>> {
1873 let git_store = project.read(cx).git_store().clone();
1874 let worktree_snapshots: Vec<_> = project
1875 .read(cx)
1876 .visible_worktrees(cx)
1877 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1878 .collect();
1879
1880 cx.spawn(async move |_, cx| {
1881 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1882
1883 let mut unsaved_buffers = Vec::new();
1884 cx.update(|app_cx| {
1885 let buffer_store = project.read(app_cx).buffer_store();
1886 for buffer_handle in buffer_store.read(app_cx).buffers() {
1887 let buffer = buffer_handle.read(app_cx);
1888 if buffer.is_dirty() {
1889 if let Some(file) = buffer.file() {
1890 let path = file.path().to_string_lossy().to_string();
1891 unsaved_buffers.push(path);
1892 }
1893 }
1894 }
1895 })
1896 .ok();
1897
1898 Arc::new(ProjectSnapshot {
1899 worktree_snapshots,
1900 unsaved_buffer_paths: unsaved_buffers,
1901 timestamp: Utc::now(),
1902 })
1903 })
1904 }
1905
    /// Builds a `WorktreeSnapshot` for one worktree: its absolute path plus,
    /// when the worktree belongs to a local git repository, the repository's
    /// remote URL, HEAD SHA, current branch, and HEAD-to-worktree diff.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, fall back to an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose root contains this worktree's path.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name.to_string());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a git backend;
                            // for anything else report just the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha();
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Flatten Option<Result<Task>> — any failure along the way yields
            // no git state rather than an error.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
1983
1984 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1985 let mut markdown = Vec::new();
1986
1987 if let Some(summary) = self.summary() {
1988 writeln!(markdown, "# {summary}\n")?;
1989 };
1990
1991 for message in self.messages() {
1992 writeln!(
1993 markdown,
1994 "## {role}\n",
1995 role = match message.role {
1996 Role::User => "User",
1997 Role::Assistant => "Assistant",
1998 Role::System => "System",
1999 }
2000 )?;
2001
2002 if !message.context.is_empty() {
2003 writeln!(markdown, "{}", message.context)?;
2004 }
2005
2006 for segment in &message.segments {
2007 match segment {
2008 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2009 MessageSegment::Thinking { text, .. } => {
2010 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2011 }
2012 MessageSegment::RedactedThinking(_) => {}
2013 }
2014 }
2015
2016 for tool_use in self.tool_uses_for_message(message.id, cx) {
2017 writeln!(
2018 markdown,
2019 "**Use Tool: {} ({})**",
2020 tool_use.name, tool_use.id
2021 )?;
2022 writeln!(markdown, "```json")?;
2023 writeln!(
2024 markdown,
2025 "{}",
2026 serde_json::to_string_pretty(&tool_use.input)?
2027 )?;
2028 writeln!(markdown, "```")?;
2029 }
2030
2031 for tool_result in self.tool_results_for_message(message.id) {
2032 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
2033 if tool_result.is_error {
2034 write!(markdown, " (Error)")?;
2035 }
2036
2037 writeln!(markdown, "**\n")?;
2038 writeln!(markdown, "{}", tool_result.content)?;
2039 }
2040 }
2041
2042 Ok(String::from_utf8_lossy(&markdown).to_string())
2043 }
2044
    /// Marks agent edits within `buffer_range` of `buffer` as kept, by
    /// delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2055
    /// Marks all outstanding agent edits as kept, via the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2060
    /// Reverts agent edits within `buffer_ranges` of `buffer`, delegating to
    /// the action log; returns the task performing the rejection.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2071
    /// The action log tracking the agent's edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2075
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2079
2080 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2081 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2082 return;
2083 }
2084
2085 let now = Instant::now();
2086 if let Some(last) = self.last_auto_capture_at {
2087 if now.duration_since(last).as_secs() < 10 {
2088 return;
2089 }
2090 }
2091
2092 self.last_auto_capture_at = Some(now);
2093
2094 let thread_id = self.id().clone();
2095 let github_login = self
2096 .project
2097 .read(cx)
2098 .user_store()
2099 .read(cx)
2100 .current_user()
2101 .map(|user| user.github_login.clone());
2102 let client = self.project.read(cx).client().clone();
2103 let serialize_task = self.serialize(cx);
2104
2105 cx.background_executor()
2106 .spawn(async move {
2107 if let Ok(serialized_thread) = serialize_task.await {
2108 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2109 telemetry::event!(
2110 "Agent Thread Auto-Captured",
2111 thread_id = thread_id.to_string(),
2112 thread_data = thread_data,
2113 auto_capture_reason = "tracked_user",
2114 github_login = github_login
2115 );
2116
2117 client.telemetry().flush_events().await;
2118 }
2119 }
2120 })
2121 .detach();
2122 }
2123
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2127
2128 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2129 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2130 return TotalTokenUsage::default();
2131 };
2132
2133 let max = model.model.max_token_count();
2134
2135 let index = self
2136 .messages
2137 .iter()
2138 .position(|msg| msg.id == message_id)
2139 .unwrap_or(0);
2140
2141 if index == 0 {
2142 return TotalTokenUsage { total: 0, max };
2143 }
2144
2145 let token_usage = &self
2146 .request_token_usage
2147 .get(index - 1)
2148 .cloned()
2149 .unwrap_or_default();
2150
2151 TotalTokenUsage {
2152 total: token_usage.total_tokens() as usize,
2153 max,
2154 }
2155 }
2156
2157 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2158 let model_registry = LanguageModelRegistry::read_global(cx);
2159 let Some(model) = model_registry.default_model() else {
2160 return TotalTokenUsage::default();
2161 };
2162
2163 let max = model.model.max_token_count();
2164
2165 if let Some(exceeded_error) = &self.exceeded_window_error {
2166 if model.model.id() == exceeded_error.model_id {
2167 return TotalTokenUsage {
2168 total: exceeded_error.token_count,
2169 max,
2170 };
2171 }
2172 }
2173
2174 let total = self
2175 .token_usage_at_last_message()
2176 .unwrap_or_default()
2177 .total_tokens() as usize;
2178
2179 TotalTokenUsage { total, max }
2180 }
2181
    /// Token usage recorded for the last message, falling back to the most
    /// recent recorded entry when `request_token_usage` is shorter than the
    /// message list. Returns `None` only when no usage was ever recorded.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2188
    /// Records `token_usage` as the usage for the last message, growing
    /// `request_token_usage` to match the message count first (new slots are
    /// filled with the previously recorded usage as a placeholder).
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2198
2199 pub fn deny_tool_use(
2200 &mut self,
2201 tool_use_id: LanguageModelToolUseId,
2202 tool_name: Arc<str>,
2203 cx: &mut Context<Self>,
2204 ) {
2205 let err = Err(anyhow::anyhow!(
2206 "Permission to run tool action denied by user"
2207 ));
2208
2209 self.tool_use
2210 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2211 self.tool_finished(tool_use_id.clone(), None, true, cx);
2212 }
2213}
2214
/// User-visible errors surfaced via `ThreadEvent::ShowError` while talking to
/// a language model.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider rejected the request pending payment.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's configured monthly spend cap has been hit.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model request quota has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, rendered with a header and detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2229
/// Events emitted by a `Thread` for the UI and other observers.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown to the user.
    ShowError(ThreadError),
    /// The provider reported updated request usage.
    UsageUpdated(RequestUsage),
    /// A completion event was applied to the thread.
    StreamedCompletion,
    /// A text chunk arrived from the model (emitted before it is applied).
    ReceivedTextChunk,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use arrived whose input JSON is still streaming in.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// A thread title was generated by the summary model.
    SummaryGenerated,
    SummaryChanged,
    /// The completion stopped with tool uses that should now run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A single tool produced its output.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// At least one tool use is waiting for user confirmation before running.
    ToolConfirmationNeeded,
}
2261
// Allow `Thread` to broadcast `ThreadEvent`s to gpui subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2263
/// A completion request that is currently streaming. Removing an entry from
/// `pending_completions` drops `_task`, which cancels the request (see
/// `cancel_last_completion`).
struct PendingCompletion {
    // Monotonic id assigned from `completion_count`, used to retain/remove
    // the right entry when the stream finishes.
    id: usize,
    _task: Task<()>,
}
2268
#[cfg(test)]
mod tests {
    //! Tests covering how user messages and attached context are assembled
    //! into completion requests, including deduplication of repeated context
    //! and stale-buffer notifications.

    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// A user message with one attached file should carry the formatted
    /// `<context>` block, and the completion request should prepend that
    /// context to the message text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Two messages: system prompt at index 0, then the user message.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    /// Context already attached to an earlier message should not be repeated
    /// in later messages — each message carries only newly added context.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // The request should contain all 3 messages
        // (plus the system prompt at index 0, hence 4 total).
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    /// Messages inserted without any context should have an empty context
    /// string and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    /// Editing a buffer after it was attached as context should append a
    /// "stale buffer" notification as the final message of the next request.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    /// Registers every settings type the thread machinery touches, so the
    /// tests above can run inside a bare `TestAppContext`.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    /// Helper to create a test project with test files rooted at `/test`.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Builds the full fixture — workspace, thread store, a fresh thread,
    /// and a context store — around the given project.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    /// Opens `path` in the project and attaches the resulting buffer to the
    /// context store, returning the buffer for further manipulation.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}