1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
42
43#[derive(
44 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
45)]
46pub struct ThreadId(Arc<str>);
47
48impl ThreadId {
49 pub fn new() -> Self {
50 Self(Uuid::new_v4().to_string().into())
51 }
52}
53
54impl std::fmt::Display for ThreadId {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
60impl From<&str> for ThreadId {
61 fn from(value: &str) -> Self {
62 Self(value.into())
63 }
64}
65
66/// The ID of the user prompt that initiated a request.
67///
68/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
70pub struct PromptId(Arc<str>);
71
72impl PromptId {
73 pub fn new() -> Self {
74 Self(Uuid::new_v4().to_string().into())
75 }
76}
77
78impl std::fmt::Display for PromptId {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "{}", self.0)
81 }
82}
83
84#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
85pub struct MessageId(pub(crate) usize);
86
87impl MessageId {
88 fn post_inc(&mut self) -> Self {
89 Self(post_inc(&mut self.0))
90 }
91}
92
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Sequential identifier assigned when the message was inserted.
    pub id: MessageId,
    /// Who authored the message (user, assistant, or system).
    pub role: Role,
    /// Ordered content segments: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Attached context, already rendered to a string for the request.
    pub context: String,
    /// Images attached by the user alongside this message.
    pub images: Vec<LanguageModelImage>,
}
102
103impl Message {
104 /// Returns whether the message contains any meaningful text that should be displayed
105 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
106 pub fn should_display_content(&self) -> bool {
107 self.segments.iter().all(|segment| segment.should_display())
108 }
109
110 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
111 if let Some(MessageSegment::Thinking {
112 text: segment,
113 signature: current_signature,
114 }) = self.segments.last_mut()
115 {
116 if let Some(signature) = signature {
117 *current_signature = Some(signature);
118 }
119 segment.push_str(text);
120 } else {
121 self.segments.push(MessageSegment::Thinking {
122 text: text.to_string(),
123 signature,
124 });
125 }
126 }
127
128 pub fn push_text(&mut self, text: &str) {
129 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
130 segment.push_str(text);
131 } else {
132 self.segments.push(MessageSegment::Text(text.to_string()));
133 }
134 }
135
136 pub fn to_string(&self) -> String {
137 let mut result = String::new();
138
139 if !self.context.is_empty() {
140 result.push_str(&self.context);
141 }
142
143 for segment in &self.segments {
144 match segment {
145 MessageSegment::Text(text) => result.push_str(text),
146 MessageSegment::Thinking { text, .. } => {
147 result.push_str("<think>\n");
148 result.push_str(text);
149 result.push_str("\n</think>");
150 }
151 MessageSegment::RedactedThinking(_) => {}
152 }
153 }
154
155 result
156 }
157}
158
/// One piece of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain response text.
    Text(String),
    /// Model "thinking" text, optionally accompanied by a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque, provider-redacted thinking data; never displayable.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment holds meaningful content worth displaying.
    ///
    /// Empty text/thinking segments (e.g. when the model only ran a tool) are
    /// not displayable, nor are redacted thinking blobs.
    ///
    /// Fix: this previously returned `text.is_empty()`, which inverted the
    /// intent documented on [`Message::should_display_content`].
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
178
/// Point-in-time record of the project's worktrees, stored alongside a serialized thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One snapshot per worktree in the project.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was captured.
    pub timestamp: DateTime<Utc>,
}
185
/// Snapshot of a single worktree's location and git state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Filesystem path of the worktree root.
    pub worktree_path: String,
    /// Git metadata, present when the worktree is a repository.
    pub git_state: Option<GitState>,
}
191
/// Git metadata captured for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// Remote URL, if one is configured.
    pub remote_url: Option<String>,
    /// SHA of the current HEAD commit, if any.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch, if any.
    pub current_branch: Option<String>,
    /// Diff contents, if collected.
    pub diff: Option<String>,
}
199
/// Associates a message with the git checkpoint taken when it was sent,
/// so the project can later be restored to that state.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
205
/// User feedback on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
211
212pub enum LastRestoreCheckpoint {
213 Pending {
214 message_id: MessageId,
215 },
216 Error {
217 message_id: MessageId,
218 error: String,
219 },
220}
221
222impl LastRestoreCheckpoint {
223 pub fn message_id(&self) -> MessageId {
224 match self {
225 LastRestoreCheckpoint::Pending { message_id } => *message_id,
226 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
227 }
228 }
229}
230
/// Lifecycle of the optional detailed thread summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// Generation is in progress; `message_id` records the associated message.
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary is available.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
243
/// Running token usage for a thread, measured against a model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens the context window allows.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or over it.
    ///
    /// In debug builds the warning threshold (default 0.8) can be overridden
    /// via the `ZED_THREAD_WARNING_THRESHOLD` environment variable.
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        // `map_or` avoids eagerly allocating the default string on every call
        // (the previous `unwrap_or("0.8".to_string())` did), and `expect`
        // states why a malformed override is a hard error.
        let warning_threshold: f32 =
            std::env::var("ZED_THREAD_WARNING_THRESHOLD").map_or(0.8, |value| {
                value
                    .parse()
                    .expect("ZED_THREAD_WARNING_THRESHOLD must be a valid f32")
            });
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When `max` is 0 the first branch always wins, so the division below
        // never runs with a zero denominator.
        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close a thread's token usage is to the model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
284
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// `None` until a summary has been generated for the conversation.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    /// Id to assign to the next inserted message.
    next_message_id: MessageId,
    /// Id of the most recent user-submitted prompt; see [`PromptId`].
    last_prompt_id: PromptId,
    /// All context ever attached to this thread, keyed by id.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Which context ids were attached by which message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    /// Git checkpoints that can restore the project to the state at a given message.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken at the last user message, finalized once generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured when the thread was created (shared future).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Token usage reported per completion request.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Optional hook observing each request and the events streamed back for it.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Turns left before `send_to_model` becomes a no-op; `u32::MAX` by default.
    remaining_turns: u32,
}
320
/// Details of the most recent context-window overflow for this thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
328
329impl Thread {
    /// Creates an empty thread for `project`, wired to the given tool set and
    /// prompt builder. A project snapshot is captured asynchronously (via
    /// `Self::project_snapshot`) and shared for later serialization.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the snapshot in the background; `.shared()` lets
                // multiple consumers await the same result.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
375
    /// Reconstructs a thread from its serialized form.
    ///
    /// State that is not persisted (prompt id, per-message context maps,
    /// checkpoints, pending completions) is reset to fresh defaults.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    context: message.context,
                    // NOTE(review): images are not persisted, so they are lost
                    // across a serialization round-trip.
                    images: Vec::new(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
450
451 pub fn set_request_callback(
452 &mut self,
453 callback: impl 'static
454 + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
455 ) {
456 self.request_callback = Some(Box::new(callback));
457 }
458
459 pub fn id(&self) -> &ThreadId {
460 &self.id
461 }
462
463 pub fn is_empty(&self) -> bool {
464 self.messages.is_empty()
465 }
466
467 pub fn updated_at(&self) -> DateTime<Utc> {
468 self.updated_at
469 }
470
471 pub fn touch_updated_at(&mut self) {
472 self.updated_at = Utc::now();
473 }
474
475 pub fn advance_prompt_id(&mut self) {
476 self.last_prompt_id = PromptId::new();
477 }
478
479 pub fn summary(&self) -> Option<SharedString> {
480 self.summary.clone()
481 }
482
483 pub fn project_context(&self) -> SharedProjectContext {
484 self.project_context.clone()
485 }
486
    /// Placeholder summary used until a real one is generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
488
489 pub fn summary_or_default(&self) -> SharedString {
490 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
491 }
492
493 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
494 let Some(current_summary) = &self.summary else {
495 // Don't allow setting summary until generated
496 return;
497 };
498
499 let mut new_summary = new_summary.into();
500
501 if new_summary.is_empty() {
502 new_summary = Self::DEFAULT_SUMMARY;
503 }
504
505 if current_summary != &new_summary {
506 self.summary = Some(new_summary);
507 cx.emit(ThreadEvent::SummaryChanged);
508 }
509 }
510
511 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
512 self.latest_detailed_summary()
513 .unwrap_or_else(|| self.text().into())
514 }
515
516 fn latest_detailed_summary(&self) -> Option<SharedString> {
517 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
518 Some(text.clone())
519 } else {
520 None
521 }
522 }
523
524 pub fn message(&self, id: MessageId) -> Option<&Message> {
525 self.messages.iter().find(|message| message.id == id)
526 }
527
528 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
529 self.messages.iter()
530 }
531
532 pub fn is_generating(&self) -> bool {
533 !self.pending_completions.is_empty() || !self.all_tools_finished()
534 }
535
536 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
537 &self.tools
538 }
539
540 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
541 self.tool_use
542 .pending_tool_uses()
543 .into_iter()
544 .find(|tool_use| &tool_use.id == id)
545 }
546
547 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
548 self.tool_use
549 .pending_tool_uses()
550 .into_iter()
551 .filter(|tool_use| tool_use.status.needs_confirmation())
552 }
553
554 pub fn has_pending_tool_uses(&self) -> bool {
555 !self.tool_use.pending_tool_uses().is_empty()
556 }
557
558 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
559 self.checkpoints_by_message.get(&id).cloned()
560 }
561
    /// Restores the project's files to `checkpoint` and truncates the thread
    /// back to the checkpointed message.
    ///
    /// Progress is tracked in `last_restore_checkpoint`; on failure the error
    /// is recorded there instead of truncating the thread.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failed state around so it can be surfaced.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
597
    /// Decides whether the checkpoint taken at the last user message is worth
    /// keeping once generation has finished.
    ///
    /// If the repository state is unchanged since the checkpoint, it is
    /// discarded; otherwise (or when comparison fails) it is stored so the
    /// user can restore it. No-op while generation is still in progress.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A comparison failure is treated as "not equal" so the
                // checkpoint is conservatively kept.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Taking the final checkpoint failed; keep the pending one rather
            // than silently dropping the user's restore point.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
636
637 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
638 self.checkpoints_by_message
639 .insert(checkpoint.message_id, checkpoint);
640 cx.emit(ThreadEvent::CheckpointChanged);
641 cx.notify();
642 }
643
644 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
645 self.last_restore_checkpoint.as_ref()
646 }
647
648 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
649 let Some(message_ix) = self
650 .messages
651 .iter()
652 .rposition(|message| message.id == message_id)
653 else {
654 return;
655 };
656 for deleted_message in self.messages.drain(message_ix..) {
657 self.context_by_message.remove(&deleted_message.id);
658 self.checkpoints_by_message.remove(&deleted_message.id);
659 }
660 cx.notify();
661 }
662
663 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
664 self.context_by_message
665 .get(&id)
666 .into_iter()
667 .flat_map(|context| {
668 context
669 .iter()
670 .filter_map(|context_id| self.context.get(&context_id))
671 })
672 }
673
674 /// Returns whether all of the tool uses have finished running.
675 pub fn all_tools_finished(&self) -> bool {
676 // If the only pending tool uses left are the ones with errors, then
677 // that means that we've finished running all of the pending tools.
678 self.tool_use
679 .pending_tool_uses()
680 .iter()
681 .all(|tool_use| tool_use.status.is_error())
682 }
683
684 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
685 self.tool_use.tool_uses_for_message(id, cx)
686 }
687
688 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
689 self.tool_use.tool_results_for_message(id)
690 }
691
692 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
693 self.tool_use.tool_result(id)
694 }
695
696 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
697 Some(&self.tool_use.tool_result(id)?.content)
698 }
699
700 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
701 self.tool_use.tool_result_card(id).cloned()
702 }
703
704 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
705 self.tool_use.message_has_tool_results(message_id)
706 }
707
708 /// Filter out contexts that have already been included in previous messages
709 pub fn filter_new_context<'a>(
710 &self,
711 context: impl Iterator<Item = &'a AssistantContext>,
712 ) -> impl Iterator<Item = &'a AssistantContext> {
713 context.filter(|ctx| self.is_context_new(ctx))
714 }
715
716 fn is_context_new(&self, context: &AssistantContext) -> bool {
717 !self.context.contains_key(&context.id())
718 }
719
    /// Inserts a user message and registers its attached context.
    ///
    /// Context already attached earlier in the thread is skipped. Buffers
    /// backing the new context are registered with the action log, and
    /// `git_checkpoint` (when provided) becomes the pending restore point
    /// for this message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        // Only keep context that wasn't attached to a previous message.
        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                // Only images whose load has already completed are attached
                // (`now_or_never` yields `None` for still-loading tasks).
                message.images = new_context
                    .iter()
                    .filter_map(|context| {
                        if let AssistantContext::Image(image_context) = context {
                            image_context.image_task.clone().now_or_never().flatten()
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>();
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.track_buffer(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
                        }
                        AssistantContext::Selection(selection_context) => {
                            log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
                        }
                        // These context kinds have no underlying buffer to track.
                        AssistantContext::FetchedUrl(_)
                        | AssistantContext::Thread(_)
                        | AssistantContext::Rules(_)
                        | AssistantContext::Image(_) => {}
                    }
                }
            });
        }

        // Record the new context and its association with this message.
        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
805
806 pub fn insert_message(
807 &mut self,
808 role: Role,
809 segments: Vec<MessageSegment>,
810 cx: &mut Context<Self>,
811 ) -> MessageId {
812 let id = self.next_message_id.post_inc();
813 self.messages.push(Message {
814 id,
815 role,
816 segments,
817 context: String::new(),
818 images: Vec::new(),
819 });
820 self.touch_updated_at();
821 cx.emit(ThreadEvent::MessageAdded(id));
822 id
823 }
824
825 pub fn edit_message(
826 &mut self,
827 id: MessageId,
828 new_role: Role,
829 new_segments: Vec<MessageSegment>,
830 cx: &mut Context<Self>,
831 ) -> bool {
832 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
833 return false;
834 };
835 message.role = new_role;
836 message.segments = new_segments;
837 self.touch_updated_at();
838 cx.emit(ThreadEvent::MessageEdited(id));
839 true
840 }
841
842 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
843 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
844 return false;
845 };
846 self.messages.remove(index);
847 self.context_by_message.remove(&id);
848 self.touch_updated_at();
849 cx.emit(ThreadEvent::MessageDeleted(id));
850 true
851 }
852
853 /// Returns the representation of this [`Thread`] in a textual form.
854 ///
855 /// This is the representation we use when attaching a thread as context to another thread.
856 pub fn text(&self) -> String {
857 let mut text = String::new();
858
859 for message in &self.messages {
860 text.push_str(match message.role {
861 language_model::Role::User => "User:",
862 language_model::Role::Assistant => "Assistant:",
863 language_model::Role::System => "System:",
864 });
865 text.push('\n');
866
867 for segment in &message.segments {
868 match segment {
869 MessageSegment::Text(content) => text.push_str(content),
870 MessageSegment::Thinking { text: content, .. } => {
871 text.push_str(&format!("<think>{}</think>", content))
872 }
873 MessageSegment::RedactedThinking(_) => {}
874 }
875 }
876 text.push('\n');
877 }
878
879 text
880 }
881
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Clone the shared future so serialization can await the snapshot
        // without blocking other consumers.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // NOTE(review): message images are not serialized.
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
945
946 pub fn remaining_turns(&self) -> u32 {
947 self.remaining_turns
948 }
949
950 pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
951 self.remaining_turns = remaining_turns;
952 }
953
954 pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
955 if self.remaining_turns == 0 {
956 return;
957 }
958
959 self.remaining_turns -= 1;
960
961 let mut request = self.to_completion_request(cx);
962 if model.supports_tools() {
963 request.tools = {
964 let mut tools = Vec::new();
965 tools.extend(
966 self.tools()
967 .read(cx)
968 .enabled_tools(cx)
969 .into_iter()
970 .filter_map(|tool| {
971 // Skip tools that cannot be supported
972 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
973 Some(LanguageModelRequestTool {
974 name: tool.name(),
975 description: tool.description(),
976 input_schema,
977 })
978 }),
979 );
980
981 tools
982 };
983 }
984
985 self.stream_completion(request, model, cx);
986 }
987
988 pub fn used_tools_since_last_user_message(&self) -> bool {
989 for message in self.messages.iter().rev() {
990 if self.tool_use.message_has_tool_results(message.id) {
991 return true;
992 } else if message.role == Role::User {
993 return false;
994 }
995 }
996
997 false
998 }
999
    /// Builds the completion request for the current thread state.
    ///
    /// The request contains, in order: the system prompt (when the project
    /// context is ready), every message with its tool results/uses, context,
    /// images, and segments, and finally a trailing message listing stale files.
    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            // The system prompt is skipped entirely in this case; the request
            // is still sent so the user isn't blocked.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Tool results are attached before this message's own content.
            self.tool_use
                .attach_tool_results(message.id, &mut request_message);

            if !message.context.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.context.to_string()));
            }

            if !message.images.is_empty() {
                // Some providers only support image parts after an initial text part
                if request_message.content.is_empty() {
                    request_message
                        .content
                        .push(MessageContent::Text("Images attached by user:".to_string()));
                }

                for image in &message.images {
                    request_message
                        .content
                        .push(MessageContent::Image(image.clone()))
                }
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1111
1112 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1113 let mut request = LanguageModelRequest {
1114 thread_id: None,
1115 prompt_id: None,
1116 messages: vec![],
1117 tools: Vec::new(),
1118 stop: Vec::new(),
1119 temperature: None,
1120 };
1121
1122 for message in &self.messages {
1123 let mut request_message = LanguageModelRequestMessage {
1124 role: message.role,
1125 content: Vec::new(),
1126 cache: false,
1127 };
1128
1129 // Skip tool results during summarization.
1130 if self.tool_use.message_has_tool_results(message.id) {
1131 continue;
1132 }
1133
1134 for segment in &message.segments {
1135 match segment {
1136 MessageSegment::Text(text) => request_message
1137 .content
1138 .push(MessageContent::Text(text.clone())),
1139 MessageSegment::Thinking { .. } => {}
1140 MessageSegment::RedactedThinking(_) => {}
1141 }
1142 }
1143
1144 if request_message.content.is_empty() {
1145 continue;
1146 }
1147
1148 request.messages.push(request_message);
1149 }
1150
1151 request.messages.push(LanguageModelRequestMessage {
1152 role: Role::User,
1153 content: vec![MessageContent::Text(added_user_message)],
1154 cache: false,
1155 });
1156
1157 request
1158 }
1159
1160 fn attached_tracked_files_state(
1161 &self,
1162 messages: &mut Vec<LanguageModelRequestMessage>,
1163 cx: &App,
1164 ) {
1165 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1166
1167 let mut stale_message = String::new();
1168
1169 let action_log = self.action_log.read(cx);
1170
1171 for stale_file in action_log.stale_buffers(cx) {
1172 let Some(file) = stale_file.read(cx).file() else {
1173 continue;
1174 };
1175
1176 if stale_message.is_empty() {
1177 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1178 }
1179
1180 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1181 }
1182
1183 let mut content = Vec::with_capacity(2);
1184
1185 if !stale_message.is_empty() {
1186 content.push(stale_message.into());
1187 }
1188
1189 if !content.is_empty() {
1190 let context_message = LanguageModelRequestMessage {
1191 role: Role::User,
1192 content,
1193 cache: false,
1194 };
1195
1196 messages.push(context_message);
1197 }
1198 }
1199
    /// Spawns a task that streams a completion from `model` for `request` and
    /// registers it as a pending completion on this thread.
    ///
    /// The spawned task applies streamed events to the thread as they arrive:
    /// text and thinking chunks are appended to (or start) the trailing
    /// Assistant message, tool-use events are recorded, and token usage is
    /// accumulated. When the stream ends with `StopReason::ToolUse`, pending
    /// tools are fired; errors are surfaced via `ThreadEvent::ShowError`.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record the request and its response events when a callback is
        // registered (used to replay the exchange after completion).
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot usage before streaming so the delta can be reported in
            // telemetry once the completion finishes.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                // Apply each streamed event to the thread on the main context.
                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // add only the delta since the last update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking {
                                                text: chunk.to_string(),
                                                signature,
                                            }],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to the last Assistant
                                // message; create one if none exists yet.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // NOTE(review): `(&tool_use.input).clone()` clones through a
                                // reference; `tool_use.input.clone()` would be equivalent.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                // Only emit partial input for still-streaming tool uses.
                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    // Yield between events so the UI stays responsive.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                // The model requested tools: run them now.
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known error types to dedicated UI errors;
                            // everything else becomes a generic message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        // Remember the overflow so token-usage UI
                                        // can show the exceeded state for this model.
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    // Replay the exchange to the registered callback, if any.
                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report the usage delta for this completion.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1469
    /// Asks the configured summary model for a short title for this thread and
    /// stores it in `self.summary`, emitting `SummaryGenerated` when done.
    ///
    /// Does nothing when no summary model is configured or the provider is not
    /// authenticated. Any in-flight summarization is replaced.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            // Inner async block so errors can be logged via `log_err` while the
            // outer task still resolves to `()`.
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                // Accumulate only the first line of output; a title should be
                // a single line.
                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // An empty response leaves any existing summary untouched.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1524
    /// Starts generating a detailed Markdown summary of the thread.
    ///
    /// Returns `None` when the thread has no messages, when a summary for the
    /// latest message is already generated or in progress, or when no
    /// authenticated summary model is configured. Otherwise returns a task
    /// that completes when generation finishes (the result is stored in
    /// `self.detailed_summary_state`).
    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request itself fails, roll the state back so a later call
            // can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            // Collect the full streamed response; chunk errors are logged and
            // skipped rather than aborting the summary.
            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        // Mark generation as in progress before handing the task back.
        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }
1591
    /// Returns true while a detailed summary is being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1598
1599 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1600 self.auto_capture_telemetry(cx);
1601 let request = self.to_completion_request(cx);
1602 let messages = Arc::new(request.messages);
1603 let pending_tool_uses = self
1604 .tool_use
1605 .pending_tool_uses()
1606 .into_iter()
1607 .filter(|tool_use| tool_use.status.is_idle())
1608 .cloned()
1609 .collect::<Vec<_>>();
1610
1611 for tool_use in pending_tool_uses.iter() {
1612 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1613 if tool.needs_confirmation(&tool_use.input, cx)
1614 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1615 {
1616 self.tool_use.confirm_tool_use(
1617 tool_use.id.clone(),
1618 tool_use.ui_text.clone(),
1619 tool_use.input.clone(),
1620 messages.clone(),
1621 tool,
1622 );
1623 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1624 } else {
1625 self.run_tool(
1626 tool_use.id.clone(),
1627 tool_use.ui_text.clone(),
1628 tool_use.input.clone(),
1629 &messages,
1630 tool,
1631 cx,
1632 );
1633 }
1634 }
1635 }
1636
1637 pending_tool_uses
1638 }
1639
1640 pub fn run_tool(
1641 &mut self,
1642 tool_use_id: LanguageModelToolUseId,
1643 ui_text: impl Into<SharedString>,
1644 input: serde_json::Value,
1645 messages: &[LanguageModelRequestMessage],
1646 tool: Arc<dyn Tool>,
1647 cx: &mut Context<Thread>,
1648 ) {
1649 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1650 self.tool_use
1651 .run_pending_tool(tool_use_id, ui_text.into(), task);
1652 }
1653
1654 fn spawn_tool_use(
1655 &mut self,
1656 tool_use_id: LanguageModelToolUseId,
1657 messages: &[LanguageModelRequestMessage],
1658 input: serde_json::Value,
1659 tool: Arc<dyn Tool>,
1660 cx: &mut Context<Thread>,
1661 ) -> Task<()> {
1662 let tool_name: Arc<str> = tool.name().into();
1663
1664 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1665 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1666 } else {
1667 tool.run(
1668 input,
1669 messages,
1670 self.project.clone(),
1671 self.action_log.clone(),
1672 cx,
1673 )
1674 };
1675
1676 // Store the card separately if it exists
1677 if let Some(card) = tool_result.card.clone() {
1678 self.tool_use
1679 .insert_tool_result_card(tool_use_id.clone(), card);
1680 }
1681
1682 cx.spawn({
1683 async move |thread: WeakEntity<Thread>, cx| {
1684 let output = tool_result.output.await;
1685
1686 thread
1687 .update(cx, |thread, cx| {
1688 let pending_tool_use = thread.tool_use.insert_tool_output(
1689 tool_use_id.clone(),
1690 tool_name,
1691 output,
1692 cx,
1693 );
1694 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1695 })
1696 .ok();
1697 }
1698 })
1699 }
1700
1701 fn tool_finished(
1702 &mut self,
1703 tool_use_id: LanguageModelToolUseId,
1704 pending_tool_use: Option<PendingToolUse>,
1705 canceled: bool,
1706 cx: &mut Context<Self>,
1707 ) {
1708 if self.all_tools_finished() {
1709 let model_registry = LanguageModelRegistry::read_global(cx);
1710 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1711 self.attach_tool_results(cx);
1712 if !canceled {
1713 self.send_to_model(model, cx);
1714 }
1715 }
1716 }
1717
1718 cx.emit(ThreadEvent::ToolFinished {
1719 tool_use_id,
1720 pending_tool_use,
1721 });
1722 }
1723
    /// Insert an empty message to be populated with tool results upon send.
    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Tool results are assumed to be waiting on the next message id, so they will populate
        // this empty message before sending to model. Would prefer this to be more straightforward.
        self.insert_message(Role::User, vec![], cx);
        // Capture a snapshot now that the thread contents changed.
        self.auto_capture_telemetry(cx);
    }
1731
1732 /// Cancels the last pending completion, if there are any pending.
1733 ///
1734 /// Returns whether a completion was canceled.
1735 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1736 let canceled = if self.pending_completions.pop().is_some() {
1737 true
1738 } else {
1739 let mut canceled = false;
1740 for pending_tool_use in self.tool_use.cancel_pending() {
1741 canceled = true;
1742 self.tool_finished(
1743 pending_tool_use.id.clone(),
1744 Some(pending_tool_use),
1745 true,
1746 cx,
1747 );
1748 }
1749 canceled
1750 };
1751 self.finalize_pending_checkpoint(cx);
1752 canceled
1753 }
1754
    /// Thread-level feedback the user has given, if any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1758
    /// Feedback the user has given on a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1762
    /// Records the user's feedback on a message and reports it via telemetry,
    /// along with a serialized copy of the thread and a project snapshot.
    ///
    /// A no-op (returning immediately with `Ok`) when the same feedback has
    /// already been recorded for this message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Gather everything that needs entity access before moving to the
        // background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure still reports the event, with null data.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
1820
    /// Records feedback for the thread as a whole.
    ///
    /// When the thread contains an Assistant message, the feedback is recorded
    /// against the most recent one; otherwise it is stored as thread-level
    /// feedback and reported via telemetry directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach to: report at the thread level.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // A serialization failure still reports the event, with null data.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
1866
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Collects one `WorktreeSnapshot` per visible worktree plus the paths of
    /// all dirty buffers, then assembles them on a background task.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Record which buffers have unsaved edits; failures to access the
            // app context simply leave the list empty.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
1904
1905 fn worktree_snapshot(
1906 worktree: Entity<project::Worktree>,
1907 git_store: Entity<GitStore>,
1908 cx: &App,
1909 ) -> Task<WorktreeSnapshot> {
1910 cx.spawn(async move |cx| {
1911 // Get worktree path and snapshot
1912 let worktree_info = cx.update(|app_cx| {
1913 let worktree = worktree.read(app_cx);
1914 let path = worktree.abs_path().to_string_lossy().to_string();
1915 let snapshot = worktree.snapshot();
1916 (path, snapshot)
1917 });
1918
1919 let Ok((worktree_path, _snapshot)) = worktree_info else {
1920 return WorktreeSnapshot {
1921 worktree_path: String::new(),
1922 git_state: None,
1923 };
1924 };
1925
1926 let git_state = git_store
1927 .update(cx, |git_store, cx| {
1928 git_store
1929 .repositories()
1930 .values()
1931 .find(|repo| {
1932 repo.read(cx)
1933 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1934 .is_some()
1935 })
1936 .cloned()
1937 })
1938 .ok()
1939 .flatten()
1940 .map(|repo| {
1941 repo.update(cx, |repo, _| {
1942 let current_branch =
1943 repo.branch.as_ref().map(|branch| branch.name.to_string());
1944 repo.send_job(None, |state, _| async move {
1945 let RepositoryState::Local { backend, .. } = state else {
1946 return GitState {
1947 remote_url: None,
1948 head_sha: None,
1949 current_branch,
1950 diff: None,
1951 };
1952 };
1953
1954 let remote_url = backend.remote_url("origin");
1955 let head_sha = backend.head_sha();
1956 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1957
1958 GitState {
1959 remote_url,
1960 head_sha,
1961 current_branch,
1962 diff,
1963 }
1964 })
1965 })
1966 });
1967
1968 let git_state = match git_state {
1969 Some(git_state) => match git_state.ok() {
1970 Some(git_state) => git_state.await.ok(),
1971 None => None,
1972 },
1973 None => None,
1974 };
1975
1976 WorktreeSnapshot {
1977 worktree_path,
1978 git_state,
1979 }
1980 })
1981 }
1982
1983 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1984 let mut markdown = Vec::new();
1985
1986 if let Some(summary) = self.summary() {
1987 writeln!(markdown, "# {summary}\n")?;
1988 };
1989
1990 for message in self.messages() {
1991 writeln!(
1992 markdown,
1993 "## {role}\n",
1994 role = match message.role {
1995 Role::User => "User",
1996 Role::Assistant => "Assistant",
1997 Role::System => "System",
1998 }
1999 )?;
2000
2001 if !message.context.is_empty() {
2002 writeln!(markdown, "{}", message.context)?;
2003 }
2004
2005 for segment in &message.segments {
2006 match segment {
2007 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2008 MessageSegment::Thinking { text, .. } => {
2009 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2010 }
2011 MessageSegment::RedactedThinking(_) => {}
2012 }
2013 }
2014
2015 for tool_use in self.tool_uses_for_message(message.id, cx) {
2016 writeln!(
2017 markdown,
2018 "**Use Tool: {} ({})**",
2019 tool_use.name, tool_use.id
2020 )?;
2021 writeln!(markdown, "```json")?;
2022 writeln!(
2023 markdown,
2024 "{}",
2025 serde_json::to_string_pretty(&tool_use.input)?
2026 )?;
2027 writeln!(markdown, "```")?;
2028 }
2029
2030 for tool_result in self.tool_results_for_message(message.id) {
2031 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
2032 if tool_result.is_error {
2033 write!(markdown, " (Error)")?;
2034 }
2035
2036 writeln!(markdown, "**\n")?;
2037 writeln!(markdown, "{}", tool_result.content)?;
2038 }
2039 }
2040
2041 Ok(String::from_utf8_lossy(&markdown).to_string())
2042 }
2043
    /// Marks the agent's edits within `buffer_range` of `buffer` as accepted.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2054
    /// Marks every edit the agent has made as accepted.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2059
    /// Reverts the agent's edits within `buffer_ranges` of `buffer`, returning
    /// a task that resolves once the rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2070
    /// The log of buffer reads/edits the agent has performed.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2074
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2078
2079 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2080 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2081 return;
2082 }
2083
2084 let now = Instant::now();
2085 if let Some(last) = self.last_auto_capture_at {
2086 if now.duration_since(last).as_secs() < 10 {
2087 return;
2088 }
2089 }
2090
2091 self.last_auto_capture_at = Some(now);
2092
2093 let thread_id = self.id().clone();
2094 let github_login = self
2095 .project
2096 .read(cx)
2097 .user_store()
2098 .read(cx)
2099 .current_user()
2100 .map(|user| user.github_login.clone());
2101 let client = self.project.read(cx).client().clone();
2102 let serialize_task = self.serialize(cx);
2103
2104 cx.background_executor()
2105 .spawn(async move {
2106 if let Ok(serialized_thread) = serialize_task.await {
2107 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2108 telemetry::event!(
2109 "Agent Thread Auto-Captured",
2110 thread_id = thread_id.to_string(),
2111 thread_data = thread_data,
2112 auto_capture_reason = "tracked_user",
2113 github_login = github_login
2114 );
2115
2116 client.telemetry().flush_events().await;
2117 }
2118 }
2119 })
2120 .detach();
2121 }
2122
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2126
2127 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2128 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2129 return TotalTokenUsage::default();
2130 };
2131
2132 let max = model.model.max_token_count();
2133
2134 let index = self
2135 .messages
2136 .iter()
2137 .position(|msg| msg.id == message_id)
2138 .unwrap_or(0);
2139
2140 if index == 0 {
2141 return TotalTokenUsage { total: 0, max };
2142 }
2143
2144 let token_usage = &self
2145 .request_token_usage
2146 .get(index - 1)
2147 .cloned()
2148 .unwrap_or_default();
2149
2150 TotalTokenUsage {
2151 total: token_usage.total_tokens() as usize,
2152 max,
2153 }
2154 }
2155
2156 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2157 let model_registry = LanguageModelRegistry::read_global(cx);
2158 let Some(model) = model_registry.default_model() else {
2159 return TotalTokenUsage::default();
2160 };
2161
2162 let max = model.model.max_token_count();
2163
2164 if let Some(exceeded_error) = &self.exceeded_window_error {
2165 if model.model.id() == exceeded_error.model_id {
2166 return TotalTokenUsage {
2167 total: exceeded_error.token_count,
2168 max,
2169 };
2170 }
2171 }
2172
2173 let total = self
2174 .token_usage_at_last_message()
2175 .unwrap_or_default()
2176 .total_tokens() as usize;
2177
2178 TotalTokenUsage { total, max }
2179 }
2180
2181 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2182 self.request_token_usage
2183 .get(self.messages.len().saturating_sub(1))
2184 .or_else(|| self.request_token_usage.last())
2185 .cloned()
2186 }
2187
    /// Records `token_usage` as the usage for the most recent message,
    /// growing the usage vector to match the message count if needed.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Pad any missing slots with the previously-known usage so earlier
        // messages keep a sensible value.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2197
2198 pub fn deny_tool_use(
2199 &mut self,
2200 tool_use_id: LanguageModelToolUseId,
2201 tool_name: Arc<str>,
2202 cx: &mut Context<Self>,
2203 ) {
2204 let err = Err(anyhow::anyhow!(
2205 "Permission to run tool action denied by user"
2206 ));
2207
2208 self.tool_use
2209 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2210 self.tool_finished(tool_use_id.clone(), None, true, cx);
2211 }
2212}
2213
/// User-visible failures that can occur while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before further requests are served.
    #[error("Payment required")]
    PaymentRequired,
    /// The account's monthly spend cap has been hit.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model-request quota has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, rendered with a header and detail text.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2228
/// Events emitted by a [`Thread`] as completions stream in, tools run, and
/// metadata (summary, checkpoints, usage) changes.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    /// Account request-usage information arrived with a response.
    UsageUpdated(RequestUsage),
    /// A streamed event was applied to the thread.
    StreamedCompletion,
    /// A text chunk arrived (emitted before the chunk is applied).
    ReceivedTextChunk,
    /// Text was appended to the given assistant message.
    StreamedAssistantText(MessageId, String),
    /// Thinking text was appended to the given assistant message.
    StreamedAssistantThinking(MessageId, String),
    /// Partial input streamed for a not-yet-complete tool use.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The completion finished, successfully or not.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// A short thread title was generated.
    SummaryGenerated,
    SummaryChanged,
    /// The model requested tool use; these tools are being started.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool use completed (or was canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires explicit user confirmation before it can run.
    ToolConfirmationNeeded,
}
2260
// Threads broadcast `ThreadEvent`s to observers via GPUI's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2262
/// Tracks an in-flight completion request.
struct PendingCompletion {
    // Identifies this request, so a finished task can be checked against the
    // current one — presumably assigned via `post_inc`; confirm at call site.
    id: usize,
    // Held only to keep the streaming task alive (never read); gpui tasks are
    // presumably canceled when dropped — confirm against gpui's `Task` docs.
    _task: Task<()>,
}
2267
2268#[cfg(test)]
2269mod tests {
2270 use super::*;
2271 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2272 use assistant_settings::AssistantSettings;
2273 use context_server::ContextServerSettings;
2274 use editor::EditorSettings;
2275 use gpui::TestAppContext;
2276 use project::{FakeFs, Project};
2277 use prompt_store::PromptBuilder;
2278 use serde_json::json;
2279 use settings::{Settings, SettingsStore};
2280 use std::sync::Arc;
2281 use theme::ThemeSettings;
2282 use util::path;
2283 use workspace::Workspace;
2284
    // Verifies that attached file context is rendered into the message's
    // `context` field and prepended to the message text in the completion
    // request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Two entries: index 0 is presumably the system/project-context
        // message, so the user message lands at index 1 — confirm against
        // `to_completion_request`.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2350
    // Verifies context deduplication across messages: each user message only
    // embeds the context items that earlier messages in the thread have not
    // already included.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // The request should contain all 3 messages
        // (4 entries total — index 0 is presumably the system/project-context
        // message, so user messages are at indices 1..=3).
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }
2452
    // Verifies that messages inserted with an empty context vector have an
    // empty `context` field and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Index 0 is presumably the system/project-context message; the user
        // message follows it.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2514
    // Verifies that when a buffer attached as context is modified after being
    // read, the next completion request ends with a "stale files" user
    // message listing the changed file.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert at offset range 1..1 (just after the first character) —
            // any edit suffices to mark the buffer stale.
            buffer.edit(
                [(1..1, "\n println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2589
    // Installs the global `SettingsStore` and registers every settings type /
    // subsystem the tests touch. The store is set first so that the
    // subsequent `init`/`register` calls can read it.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }
2605
2606 // Helper to create a test project with test files
2607 async fn create_test_project(
2608 cx: &mut TestAppContext,
2609 files: serde_json::Value,
2610 ) -> Entity<Project> {
2611 let fs = FakeFs::new(cx.executor());
2612 fs.insert_tree(path!("/test"), files).await;
2613 Project::test(fs, [path!("/test").as_ref()], cx).await
2614 }
2615
2616 async fn setup_test_environment(
2617 cx: &mut TestAppContext,
2618 project: Entity<Project>,
2619 ) -> (
2620 Entity<Workspace>,
2621 Entity<ThreadStore>,
2622 Entity<Thread>,
2623 Entity<ContextStore>,
2624 ) {
2625 let (workspace, cx) =
2626 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2627
2628 let thread_store = cx
2629 .update(|_, cx| {
2630 ThreadStore::load(
2631 project.clone(),
2632 cx.new(|_| ToolWorkingSet::default()),
2633 Arc::new(PromptBuilder::new(None).unwrap()),
2634 cx,
2635 )
2636 })
2637 .await
2638 .unwrap();
2639
2640 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2641 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2642
2643 (workspace, thread_store, thread, context_store)
2644 }
2645
2646 async fn add_file_to_context(
2647 project: &Entity<Project>,
2648 context_store: &Entity<ContextStore>,
2649 path: &str,
2650 cx: &mut TestAppContext,
2651 ) -> Result<Entity<language::Buffer>> {
2652 let buffer_path = project
2653 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2654 .unwrap();
2655
2656 let buffer = project
2657 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2658 .await
2659 .unwrap();
2660
2661 context_store
2662 .update(cx, |store, cx| {
2663 store.add_file_from_buffer(buffer.clone(), cx)
2664 })
2665 .await?;
2666
2667 Ok(buffer)
2668 }
2669}