1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
42
43#[derive(
44 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
45)]
46pub struct ThreadId(Arc<str>);
47
48impl ThreadId {
49 pub fn new() -> Self {
50 Self(Uuid::new_v4().to_string().into())
51 }
52}
53
54impl std::fmt::Display for ThreadId {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
60impl From<&str> for ThreadId {
61 fn from(value: &str) -> Self {
62 Self(value.into())
63 }
64}
65
66/// The ID of the user prompt that initiated a request.
67///
68/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
70pub struct PromptId(Arc<str>);
71
72impl PromptId {
73 pub fn new() -> Self {
74 Self(Uuid::new_v4().to_string().into())
75 }
76}
77
78impl std::fmt::Display for PromptId {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "{}", self.0)
81 }
82}
83
84#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
85pub struct MessageId(pub(crate) usize);
86
87impl MessageId {
88 fn post_inc(&mut self) -> Self {
89 Self(post_inc(&mut self.0))
90 }
91}
92
93/// A message in a [`Thread`].
94#[derive(Debug, Clone)]
95pub struct Message {
96 pub id: MessageId,
97 pub role: Role,
98 pub segments: Vec<MessageSegment>,
99 pub context: String,
100}
101
102impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    ///
    /// The model sometimes runs a tool without producing any text, or produces only a marker ([`USING_TOOL_MARKER`]).
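    ///
    /// A minimal sketch of the intended behavior (not compiled as a doctest, since
    /// `MessageId` cannot be constructed from outside this crate):
    ///
    /// ```ignore
    /// // A message whose only segment is empty text has nothing worth displaying.
    /// let empty = Message {
    ///     id: MessageId(0),
    ///     role: Role::Assistant,
    ///     segments: vec![MessageSegment::Text(String::new())],
    ///     context: String::new(),
    /// };
    /// assert!(!empty.should_display_content());
    /// ```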
105 pub fn should_display_content(&self) -> bool {
106 self.segments.iter().all(|segment| segment.should_display())
107 }
108
109 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
110 if let Some(MessageSegment::Thinking {
111 text: segment,
112 signature: current_signature,
113 }) = self.segments.last_mut()
114 {
115 if let Some(signature) = signature {
116 *current_signature = Some(signature);
117 }
118 segment.push_str(text);
119 } else {
120 self.segments.push(MessageSegment::Thinking {
121 text: text.to_string(),
122 signature,
123 });
124 }
125 }
126
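    /// Appends `text` to the trailing [`MessageSegment::Text`], or starts a new text
    /// segment if the last segment is not text (or there are no segments yet).
    ///
    /// Illustrative sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// message.push_text("Hello, ");
    /// message.push_text("world");
    /// // Both chunks are coalesced into a single `MessageSegment::Text("Hello, world")`.
    /// ```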
127 pub fn push_text(&mut self, text: &str) {
128 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
129 segment.push_str(text);
130 } else {
131 self.segments.push(MessageSegment::Text(text.to_string()));
132 }
133 }
134
135 pub fn to_string(&self) -> String {
136 let mut result = String::new();
137
138 if !self.context.is_empty() {
139 result.push_str(&self.context);
140 }
141
142 for segment in &self.segments {
143 match segment {
144 MessageSegment::Text(text) => result.push_str(text),
145 MessageSegment::Thinking { text, .. } => {
146 result.push_str("<think>\n");
147 result.push_str(text);
148 result.push_str("\n</think>");
149 }
150 MessageSegment::RedactedThinking(_) => {}
151 }
152 }
153
154 result
155 }
156}
157
158#[derive(Debug, Clone, PartialEq, Eq)]
159pub enum MessageSegment {
160 Text(String),
161 Thinking {
162 text: String,
163 signature: Option<String>,
164 },
165 RedactedThinking(Vec<u8>),
166}
167
168impl MessageSegment {
169 pub fn should_display(&self) -> bool {
170 match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
173 Self::RedactedThinking(_) => false,
174 }
175 }
176}
177
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct ProjectSnapshot {
180 pub worktree_snapshots: Vec<WorktreeSnapshot>,
181 pub unsaved_buffer_paths: Vec<String>,
182 pub timestamp: DateTime<Utc>,
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct WorktreeSnapshot {
187 pub worktree_path: String,
188 pub git_state: Option<GitState>,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct GitState {
193 pub remote_url: Option<String>,
194 pub head_sha: Option<String>,
195 pub current_branch: Option<String>,
196 pub diff: Option<String>,
197}
198
199#[derive(Clone)]
200pub struct ThreadCheckpoint {
201 message_id: MessageId,
202 git_checkpoint: GitStoreCheckpoint,
203}
204
205#[derive(Copy, Clone, Debug, PartialEq, Eq)]
206pub enum ThreadFeedback {
207 Positive,
208 Negative,
209}
210
211pub enum LastRestoreCheckpoint {
212 Pending {
213 message_id: MessageId,
214 },
215 Error {
216 message_id: MessageId,
217 error: String,
218 },
219}
220
221impl LastRestoreCheckpoint {
222 pub fn message_id(&self) -> MessageId {
223 match self {
224 LastRestoreCheckpoint::Pending { message_id } => *message_id,
225 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
226 }
227 }
228}
229
230#[derive(Clone, Debug, Default, Serialize, Deserialize)]
231pub enum DetailedSummaryState {
232 #[default]
233 NotGenerated,
234 Generating {
235 message_id: MessageId,
236 },
237 Generated {
238 text: SharedString,
239 message_id: MessageId,
240 },
241}
242
243#[derive(Default)]
244pub struct TotalTokenUsage {
245 pub total: usize,
246 pub max: usize,
247}
248
249impl TotalTokenUsage {
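    /// Classifies the current usage relative to the context window: `Exceeded` at or
    /// above the maximum, `Warning` at or above the warning threshold (0.8 by default,
    /// overridable via `ZED_THREAD_WARNING_THRESHOLD` in debug builds), and `Normal`
    /// otherwise.
    ///
    /// Illustrative sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let usage = TotalTokenUsage { total: 850, max: 1000 };
    /// assert_eq!(usage.ratio(), TokenUsageRatio::Warning);
    ///
    /// let usage = usage.add(150);
    /// assert_eq!(usage.ratio(), TokenUsageRatio::Exceeded);
    /// ```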
250 pub fn ratio(&self) -> TokenUsageRatio {
251 #[cfg(debug_assertions)]
252 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
253 .unwrap_or("0.8".to_string())
254 .parse()
255 .unwrap();
256 #[cfg(not(debug_assertions))]
257 let warning_threshold: f32 = 0.8;
258
259 if self.total >= self.max {
260 TokenUsageRatio::Exceeded
261 } else if self.total as f32 / self.max as f32 >= warning_threshold {
262 TokenUsageRatio::Warning
263 } else {
264 TokenUsageRatio::Normal
265 }
266 }
267
268 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
269 TotalTokenUsage {
270 total: self.total + tokens,
271 max: self.max,
272 }
273 }
274}
275
276#[derive(Debug, Default, PartialEq, Eq)]
277pub enum TokenUsageRatio {
278 #[default]
279 Normal,
280 Warning,
281 Exceeded,
282}
283
284/// A thread of conversation with the LLM.
285pub struct Thread {
286 id: ThreadId,
287 updated_at: DateTime<Utc>,
288 summary: Option<SharedString>,
289 pending_summary: Task<Option<()>>,
290 detailed_summary_state: DetailedSummaryState,
291 messages: Vec<Message>,
292 next_message_id: MessageId,
293 last_prompt_id: PromptId,
294 context: BTreeMap<ContextId, AssistantContext>,
295 context_by_message: HashMap<MessageId, Vec<ContextId>>,
296 project_context: SharedProjectContext,
297 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
298 completion_count: usize,
299 pending_completions: Vec<PendingCompletion>,
300 project: Entity<Project>,
301 prompt_builder: Arc<PromptBuilder>,
302 tools: Entity<ToolWorkingSet>,
303 tool_use: ToolUseState,
304 action_log: Entity<ActionLog>,
305 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
306 pending_checkpoint: Option<ThreadCheckpoint>,
307 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
308 request_token_usage: Vec<TokenUsage>,
309 cumulative_token_usage: TokenUsage,
310 exceeded_window_error: Option<ExceededWindowError>,
311 feedback: Option<ThreadFeedback>,
312 message_feedback: HashMap<MessageId, ThreadFeedback>,
313 last_auto_capture_at: Option<Instant>,
314 request_callback: Option<
315 Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
316 >,
317}
318
319#[derive(Debug, Clone, Serialize, Deserialize)]
320pub struct ExceededWindowError {
321 /// Model used when last message exceeded context window
322 model_id: LanguageModelId,
323 /// Token count including last message
324 token_count: usize,
325}
326
327impl Thread {
328 pub fn new(
329 project: Entity<Project>,
330 tools: Entity<ToolWorkingSet>,
331 prompt_builder: Arc<PromptBuilder>,
332 system_prompt: SharedProjectContext,
333 cx: &mut Context<Self>,
334 ) -> Self {
335 Self {
336 id: ThreadId::new(),
337 updated_at: Utc::now(),
338 summary: None,
339 pending_summary: Task::ready(None),
340 detailed_summary_state: DetailedSummaryState::NotGenerated,
341 messages: Vec::new(),
342 next_message_id: MessageId(0),
343 last_prompt_id: PromptId::new(),
344 context: BTreeMap::default(),
345 context_by_message: HashMap::default(),
346 project_context: system_prompt,
347 checkpoints_by_message: HashMap::default(),
348 completion_count: 0,
349 pending_completions: Vec::new(),
350 project: project.clone(),
351 prompt_builder,
352 tools: tools.clone(),
353 last_restore_checkpoint: None,
354 pending_checkpoint: None,
355 tool_use: ToolUseState::new(tools.clone()),
356 action_log: cx.new(|_| ActionLog::new(project.clone())),
357 initial_project_snapshot: {
358 let project_snapshot = Self::project_snapshot(project, cx);
359 cx.foreground_executor()
360 .spawn(async move { Some(project_snapshot.await) })
361 .shared()
362 },
363 request_token_usage: Vec::new(),
364 cumulative_token_usage: TokenUsage::default(),
365 exceeded_window_error: None,
366 feedback: None,
367 message_feedback: HashMap::default(),
368 last_auto_capture_at: None,
369 request_callback: None,
370 }
371 }
372
373 pub fn deserialize(
374 id: ThreadId,
375 serialized: SerializedThread,
376 project: Entity<Project>,
377 tools: Entity<ToolWorkingSet>,
378 prompt_builder: Arc<PromptBuilder>,
379 project_context: SharedProjectContext,
380 cx: &mut Context<Self>,
381 ) -> Self {
382 let next_message_id = MessageId(
383 serialized
384 .messages
385 .last()
386 .map(|message| message.id.0 + 1)
387 .unwrap_or(0),
388 );
389 let tool_use =
390 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
391
392 Self {
393 id,
394 updated_at: serialized.updated_at,
395 summary: Some(serialized.summary),
396 pending_summary: Task::ready(None),
397 detailed_summary_state: serialized.detailed_summary_state,
398 messages: serialized
399 .messages
400 .into_iter()
401 .map(|message| Message {
402 id: message.id,
403 role: message.role,
404 segments: message
405 .segments
406 .into_iter()
407 .map(|segment| match segment {
408 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
409 SerializedMessageSegment::Thinking { text, signature } => {
410 MessageSegment::Thinking { text, signature }
411 }
412 SerializedMessageSegment::RedactedThinking { data } => {
413 MessageSegment::RedactedThinking(data)
414 }
415 })
416 .collect(),
417 context: message.context,
418 })
419 .collect(),
420 next_message_id,
421 last_prompt_id: PromptId::new(),
422 context: BTreeMap::default(),
423 context_by_message: HashMap::default(),
424 project_context,
425 checkpoints_by_message: HashMap::default(),
426 completion_count: 0,
427 pending_completions: Vec::new(),
428 last_restore_checkpoint: None,
429 pending_checkpoint: None,
430 project: project.clone(),
431 prompt_builder,
432 tools,
433 tool_use,
434 action_log: cx.new(|_| ActionLog::new(project)),
435 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
436 request_token_usage: serialized.request_token_usage,
437 cumulative_token_usage: serialized.cumulative_token_usage,
438 exceeded_window_error: None,
439 feedback: None,
440 message_feedback: HashMap::default(),
441 last_auto_capture_at: None,
442 request_callback: None,
443 }
444 }
445
446 pub fn set_request_callback(
447 &mut self,
448 callback: impl 'static
449 + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
450 ) {
451 self.request_callback = Some(Box::new(callback));
452 }
453
454 pub fn id(&self) -> &ThreadId {
455 &self.id
456 }
457
458 pub fn is_empty(&self) -> bool {
459 self.messages.is_empty()
460 }
461
462 pub fn updated_at(&self) -> DateTime<Utc> {
463 self.updated_at
464 }
465
466 pub fn touch_updated_at(&mut self) {
467 self.updated_at = Utc::now();
468 }
469
470 pub fn advance_prompt_id(&mut self) {
471 self.last_prompt_id = PromptId::new();
472 }
473
474 pub fn summary(&self) -> Option<SharedString> {
475 self.summary.clone()
476 }
477
478 pub fn project_context(&self) -> SharedProjectContext {
479 self.project_context.clone()
480 }
481
482 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
483
484 pub fn summary_or_default(&self) -> SharedString {
485 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
486 }
487
488 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
489 let Some(current_summary) = &self.summary else {
490 // Don't allow setting summary until generated
491 return;
492 };
493
494 let mut new_summary = new_summary.into();
495
496 if new_summary.is_empty() {
497 new_summary = Self::DEFAULT_SUMMARY;
498 }
499
500 if current_summary != &new_summary {
501 self.summary = Some(new_summary);
502 cx.emit(ThreadEvent::SummaryChanged);
503 }
504 }
505
506 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
507 self.latest_detailed_summary()
508 .unwrap_or_else(|| self.text().into())
509 }
510
511 fn latest_detailed_summary(&self) -> Option<SharedString> {
512 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
513 Some(text.clone())
514 } else {
515 None
516 }
517 }
518
519 pub fn message(&self, id: MessageId) -> Option<&Message> {
520 self.messages.iter().find(|message| message.id == id)
521 }
522
523 pub fn messages(&self) -> impl Iterator<Item = &Message> {
524 self.messages.iter()
525 }
526
527 pub fn is_generating(&self) -> bool {
528 !self.pending_completions.is_empty() || !self.all_tools_finished()
529 }
530
531 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
532 &self.tools
533 }
534
535 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
536 self.tool_use
537 .pending_tool_uses()
538 .into_iter()
539 .find(|tool_use| &tool_use.id == id)
540 }
541
542 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
543 self.tool_use
544 .pending_tool_uses()
545 .into_iter()
546 .filter(|tool_use| tool_use.status.needs_confirmation())
547 }
548
549 pub fn has_pending_tool_uses(&self) -> bool {
550 !self.tool_use.pending_tool_uses().is_empty()
551 }
552
553 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
554 self.checkpoints_by_message.get(&id).cloned()
555 }
556
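    /// Restores the project to the given checkpoint and, on success, truncates the
    /// thread back to the message the checkpoint was taken for.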
557 pub fn restore_checkpoint(
558 &mut self,
559 checkpoint: ThreadCheckpoint,
560 cx: &mut Context<Self>,
561 ) -> Task<Result<()>> {
562 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
563 message_id: checkpoint.message_id,
564 });
565 cx.emit(ThreadEvent::CheckpointChanged);
566 cx.notify();
567
568 let git_store = self.project().read(cx).git_store().clone();
569 let restore = git_store.update(cx, |git_store, cx| {
570 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
571 });
572
573 cx.spawn(async move |this, cx| {
574 let result = restore.await;
575 this.update(cx, |this, cx| {
576 if let Err(err) = result.as_ref() {
577 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
578 message_id: checkpoint.message_id,
579 error: err.to_string(),
580 });
581 } else {
582 this.truncate(checkpoint.message_id, cx);
583 this.last_restore_checkpoint = None;
584 }
585 this.pending_checkpoint = None;
586 cx.emit(ThreadEvent::CheckpointChanged);
587 cx.notify();
588 })?;
589 result
590 })
591 }
592
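    /// Resolves the checkpoint taken when the last user message was sent: if the
    /// project state is unchanged since then, the pending checkpoint is discarded;
    /// otherwise (or if the comparison fails) it is recorded so the user can restore
    /// to that message later.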
593 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
594 let pending_checkpoint = if self.is_generating() {
595 return;
596 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
597 checkpoint
598 } else {
599 return;
600 };
601
602 let git_store = self.project.read(cx).git_store().clone();
603 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
604 cx.spawn(async move |this, cx| match final_checkpoint.await {
605 Ok(final_checkpoint) => {
606 let equal = git_store
607 .update(cx, |store, cx| {
608 store.compare_checkpoints(
609 pending_checkpoint.git_checkpoint.clone(),
610 final_checkpoint.clone(),
611 cx,
612 )
613 })?
614 .await
615 .unwrap_or(false);
616
617 if equal {
618 git_store
619 .update(cx, |store, cx| {
620 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
621 })?
622 .detach();
623 } else {
624 this.update(cx, |this, cx| {
625 this.insert_checkpoint(pending_checkpoint, cx)
626 })?;
627 }
628
629 git_store
630 .update(cx, |store, cx| {
631 store.delete_checkpoint(final_checkpoint, cx)
632 })?
633 .detach();
634
635 Ok(())
636 }
637 Err(_) => this.update(cx, |this, cx| {
638 this.insert_checkpoint(pending_checkpoint, cx)
639 }),
640 })
641 .detach();
642 }
643
644 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
645 self.checkpoints_by_message
646 .insert(checkpoint.message_id, checkpoint);
647 cx.emit(ThreadEvent::CheckpointChanged);
648 cx.notify();
649 }
650
651 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
652 self.last_restore_checkpoint.as_ref()
653 }
654
655 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
656 let Some(message_ix) = self
657 .messages
658 .iter()
659 .rposition(|message| message.id == message_id)
660 else {
661 return;
662 };
663 for deleted_message in self.messages.drain(message_ix..) {
664 self.context_by_message.remove(&deleted_message.id);
665 self.checkpoints_by_message.remove(&deleted_message.id);
666 }
667 cx.notify();
668 }
669
670 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
671 self.context_by_message
672 .get(&id)
673 .into_iter()
674 .flat_map(|context| {
675 context
676 .iter()
677 .filter_map(|context_id| self.context.get(&context_id))
678 })
679 }
680
681 /// Returns whether all of the tool uses have finished running.
682 pub fn all_tools_finished(&self) -> bool {
683 // If the only pending tool uses left are the ones with errors, then
684 // that means that we've finished running all of the pending tools.
685 self.tool_use
686 .pending_tool_uses()
687 .iter()
688 .all(|tool_use| tool_use.status.is_error())
689 }
690
691 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
692 self.tool_use.tool_uses_for_message(id, cx)
693 }
694
695 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
696 self.tool_use.tool_results_for_message(id)
697 }
698
699 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
700 self.tool_use.tool_result(id)
701 }
702
703 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
704 Some(&self.tool_use.tool_result(id)?.content)
705 }
706
707 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
708 self.tool_use.tool_result_card(id).cloned()
709 }
710
711 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
712 self.tool_use.message_has_tool_results(message_id)
713 }
714
715 /// Filter out contexts that have already been included in previous messages
716 pub fn filter_new_context<'a>(
717 &self,
718 context: impl Iterator<Item = &'a AssistantContext>,
719 ) -> impl Iterator<Item = &'a AssistantContext> {
720 context.filter(|ctx| self.is_context_new(ctx))
721 }
722
723 fn is_context_new(&self, context: &AssistantContext) -> bool {
724 !self.context.contains_key(&context.id())
725 }
726
727 pub fn insert_user_message(
728 &mut self,
729 text: impl Into<String>,
730 context: Vec<AssistantContext>,
731 git_checkpoint: Option<GitStoreCheckpoint>,
732 cx: &mut Context<Self>,
733 ) -> MessageId {
734 let text = text.into();
735
736 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
737
738 let new_context: Vec<_> = context
739 .into_iter()
740 .filter(|ctx| self.is_context_new(ctx))
741 .collect();
742
743 if !new_context.is_empty() {
744 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
745 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
746 message.context = context_string;
747 }
748 }
749
750 self.action_log.update(cx, |log, cx| {
751 // Track all buffers added as context
752 for ctx in &new_context {
753 match ctx {
754 AssistantContext::File(file_ctx) => {
755 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
756 }
757 AssistantContext::Directory(dir_ctx) => {
758 for context_buffer in &dir_ctx.context_buffers {
759 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
760 }
761 }
762 AssistantContext::Symbol(symbol_ctx) => {
763 log.buffer_added_as_context(
764 symbol_ctx.context_symbol.buffer.clone(),
765 cx,
766 );
767 }
768 AssistantContext::Excerpt(excerpt_context) => {
769 log.buffer_added_as_context(
770 excerpt_context.context_buffer.buffer.clone(),
771 cx,
772 );
773 }
774 AssistantContext::FetchedUrl(_)
775 | AssistantContext::Thread(_)
776 | AssistantContext::Rules(_) => {}
777 }
778 }
779 });
780 }
781
782 let context_ids = new_context
783 .iter()
784 .map(|context| context.id())
785 .collect::<Vec<_>>();
786 self.context.extend(
787 new_context
788 .into_iter()
789 .map(|context| (context.id(), context)),
790 );
791 self.context_by_message.insert(message_id, context_ids);
792
793 if let Some(git_checkpoint) = git_checkpoint {
794 self.pending_checkpoint = Some(ThreadCheckpoint {
795 message_id,
796 git_checkpoint,
797 });
798 }
799
800 self.auto_capture_telemetry(cx);
801
802 message_id
803 }
804
805 pub fn insert_message(
806 &mut self,
807 role: Role,
808 segments: Vec<MessageSegment>,
809 cx: &mut Context<Self>,
810 ) -> MessageId {
811 let id = self.next_message_id.post_inc();
812 self.messages.push(Message {
813 id,
814 role,
815 segments,
816 context: String::new(),
817 });
818 self.touch_updated_at();
819 cx.emit(ThreadEvent::MessageAdded(id));
820 id
821 }
822
823 pub fn edit_message(
824 &mut self,
825 id: MessageId,
826 new_role: Role,
827 new_segments: Vec<MessageSegment>,
828 cx: &mut Context<Self>,
829 ) -> bool {
830 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
831 return false;
832 };
833 message.role = new_role;
834 message.segments = new_segments;
835 self.touch_updated_at();
836 cx.emit(ThreadEvent::MessageEdited(id));
837 true
838 }
839
840 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
841 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
842 return false;
843 };
844 self.messages.remove(index);
845 self.context_by_message.remove(&id);
846 self.touch_updated_at();
847 cx.emit(ThreadEvent::MessageDeleted(id));
848 true
849 }
850
851 /// Returns the representation of this [`Thread`] in a textual form.
852 ///
853 /// This is the representation we use when attaching a thread as context to another thread.
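    ///
    /// The output looks roughly like this (illustrative):
    ///
    /// ```text
    /// User:
    /// How do I exit Vim?
    /// Assistant:
    /// <think>recalling the command</think>Press Escape, then type `:q!`.
    /// ```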
854 pub fn text(&self) -> String {
855 let mut text = String::new();
856
857 for message in &self.messages {
858 text.push_str(match message.role {
859 language_model::Role::User => "User:",
860 language_model::Role::Assistant => "Assistant:",
861 language_model::Role::System => "System:",
862 });
863 text.push('\n');
864
865 for segment in &message.segments {
866 match segment {
867 MessageSegment::Text(content) => text.push_str(content),
868 MessageSegment::Thinking { text: content, .. } => {
869 text.push_str(&format!("<think>{}</think>", content))
870 }
871 MessageSegment::RedactedThinking(_) => {}
872 }
873 }
874 text.push('\n');
875 }
876
877 text
878 }
879
880 /// Serializes this thread into a format for storage or telemetry.
881 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
882 let initial_project_snapshot = self.initial_project_snapshot.clone();
883 cx.spawn(async move |this, cx| {
884 let initial_project_snapshot = initial_project_snapshot.await;
885 this.read_with(cx, |this, cx| SerializedThread {
886 version: SerializedThread::VERSION.to_string(),
887 summary: this.summary_or_default(),
888 updated_at: this.updated_at(),
889 messages: this
890 .messages()
891 .map(|message| SerializedMessage {
892 id: message.id,
893 role: message.role,
894 segments: message
895 .segments
896 .iter()
897 .map(|segment| match segment {
898 MessageSegment::Text(text) => {
899 SerializedMessageSegment::Text { text: text.clone() }
900 }
901 MessageSegment::Thinking { text, signature } => {
902 SerializedMessageSegment::Thinking {
903 text: text.clone(),
904 signature: signature.clone(),
905 }
906 }
907 MessageSegment::RedactedThinking(data) => {
908 SerializedMessageSegment::RedactedThinking {
909 data: data.clone(),
910 }
911 }
912 })
913 .collect(),
914 tool_uses: this
915 .tool_uses_for_message(message.id, cx)
916 .into_iter()
917 .map(|tool_use| SerializedToolUse {
918 id: tool_use.id,
919 name: tool_use.name,
920 input: tool_use.input,
921 })
922 .collect(),
923 tool_results: this
924 .tool_results_for_message(message.id)
925 .into_iter()
926 .map(|tool_result| SerializedToolResult {
927 tool_use_id: tool_result.tool_use_id.clone(),
928 is_error: tool_result.is_error,
929 content: tool_result.content.clone(),
930 })
931 .collect(),
932 context: message.context.clone(),
933 })
934 .collect(),
935 initial_project_snapshot,
936 cumulative_token_usage: this.cumulative_token_usage,
937 request_token_usage: this.request_token_usage.clone(),
938 detailed_summary_state: this.detailed_summary_state.clone(),
939 exceeded_window_error: this.exceeded_window_error.clone(),
940 })
941 })
942 }
943
944 pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
945 let mut request = self.to_completion_request(cx);
946 if model.supports_tools() {
947 request.tools = {
948 let mut tools = Vec::new();
949 tools.extend(
950 self.tools()
951 .read(cx)
952 .enabled_tools(cx)
953 .into_iter()
954 .filter_map(|tool| {
                            // Skip tools whose input schema cannot be generated for this model's tool input format
956 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
957 Some(LanguageModelRequestTool {
958 name: tool.name(),
959 description: tool.description(),
960 input_schema,
961 })
962 }),
963 );
964
965 tools
966 };
967 }
968
969 self.stream_completion(request, model, cx);
970 }
971
972 pub fn used_tools_since_last_user_message(&self) -> bool {
973 for message in self.messages.iter().rev() {
974 if self.tool_use.message_has_tool_results(message.id) {
975 return true;
976 } else if message.role == Role::User {
977 return false;
978 }
979 }
980
981 false
982 }
983
984 pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
985 let mut request = LanguageModelRequest {
986 thread_id: Some(self.id.to_string()),
987 prompt_id: Some(self.last_prompt_id.to_string()),
988 messages: vec![],
989 tools: Vec::new(),
990 stop: Vec::new(),
991 temperature: None,
992 };
993
994 if let Some(project_context) = self.project_context.borrow().as_ref() {
995 match self
996 .prompt_builder
997 .generate_assistant_system_prompt(project_context)
998 {
999 Err(err) => {
1000 let message = format!("{err:?}").into();
1001 log::error!("{message}");
1002 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1003 header: "Error generating system prompt".into(),
1004 message,
1005 }));
1006 }
1007 Ok(system_prompt) => {
1008 request.messages.push(LanguageModelRequestMessage {
1009 role: Role::System,
1010 content: vec![MessageContent::Text(system_prompt)],
1011 cache: true,
1012 });
1013 }
1014 }
1015 } else {
1016 let message = "Context for system prompt unexpectedly not ready.".into();
1017 log::error!("{message}");
1018 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1019 header: "Error generating system prompt".into(),
1020 message,
1021 }));
1022 }
1023
1024 for message in &self.messages {
1025 let mut request_message = LanguageModelRequestMessage {
1026 role: message.role,
1027 content: Vec::new(),
1028 cache: false,
1029 };
1030
1031 self.tool_use
1032 .attach_tool_results(message.id, &mut request_message);
1033
1034 if !message.context.is_empty() {
1035 request_message
1036 .content
1037 .push(MessageContent::Text(message.context.to_string()));
1038 }
1039
1040 for segment in &message.segments {
1041 match segment {
1042 MessageSegment::Text(text) => {
1043 if !text.is_empty() {
1044 request_message
1045 .content
1046 .push(MessageContent::Text(text.into()));
1047 }
1048 }
1049 MessageSegment::Thinking { text, signature } => {
1050 if !text.is_empty() {
1051 request_message.content.push(MessageContent::Thinking {
1052 text: text.into(),
1053 signature: signature.clone(),
1054 });
1055 }
1056 }
1057 MessageSegment::RedactedThinking(data) => {
1058 request_message
1059 .content
1060 .push(MessageContent::RedactedThinking(data.clone()));
1061 }
1062 };
1063 }
1064
1065 self.tool_use
1066 .attach_tool_uses(message.id, &mut request_message);
1067
1068 request.messages.push(request_message);
1069 }
1070
1071 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
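        // Marking the final message as cacheable sets the cache breakpoint, so providers
        // that support prompt caching (e.g. Anthropic) can reuse the shared prefix on the
        // next request.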
1072 if let Some(last) = request.messages.last_mut() {
1073 last.cache = true;
1074 }
1075
1076 self.attached_tracked_files_state(&mut request.messages, cx);
1077
1078 request
1079 }
1080
1081 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1082 let mut request = LanguageModelRequest {
1083 thread_id: None,
1084 prompt_id: None,
1085 messages: vec![],
1086 tools: Vec::new(),
1087 stop: Vec::new(),
1088 temperature: None,
1089 };
1090
1091 for message in &self.messages {
1092 let mut request_message = LanguageModelRequestMessage {
1093 role: message.role,
1094 content: Vec::new(),
1095 cache: false,
1096 };
1097
1098 // Skip tool results during summarization.
1099 if self.tool_use.message_has_tool_results(message.id) {
1100 continue;
1101 }
1102
1103 for segment in &message.segments {
1104 match segment {
1105 MessageSegment::Text(text) => request_message
1106 .content
1107 .push(MessageContent::Text(text.clone())),
1108 MessageSegment::Thinking { .. } => {}
1109 MessageSegment::RedactedThinking(_) => {}
1110 }
1111 }
1112
1113 if request_message.content.is_empty() {
1114 continue;
1115 }
1116
1117 request.messages.push(request_message);
1118 }
1119
1120 request.messages.push(LanguageModelRequestMessage {
1121 role: Role::User,
1122 content: vec![MessageContent::Text(added_user_message)],
1123 cache: false,
1124 });
1125
1126 request
1127 }
1128
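    /// Appends a trailing user message listing any tracked files that changed since the
    /// model last read them, e.g. (illustrative):
    ///
    /// ```text
    /// These files changed since last read:
    /// - src/thread.rs
    /// - README.md
    /// ```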
1129 fn attached_tracked_files_state(
1130 &self,
1131 messages: &mut Vec<LanguageModelRequestMessage>,
1132 cx: &App,
1133 ) {
1134 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1135
1136 let mut stale_message = String::new();
1137
1138 let action_log = self.action_log.read(cx);
1139
1140 for stale_file in action_log.stale_buffers(cx) {
1141 let Some(file) = stale_file.read(cx).file() else {
1142 continue;
1143 };
1144
1145 if stale_message.is_empty() {
                writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
1147 }
1148
1149 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1150 }
1151
1152 let mut content = Vec::with_capacity(2);
1153
1154 if !stale_message.is_empty() {
1155 content.push(stale_message.into());
1156 }
1157
1158 if !content.is_empty() {
1159 let context_message = LanguageModelRequestMessage {
1160 role: Role::User,
1161 content,
1162 cache: false,
1163 };
1164
1165 messages.push(context_message);
1166 }
1167 }
1168
1169 pub fn stream_completion(
1170 &mut self,
1171 request: LanguageModelRequest,
1172 model: Arc<dyn LanguageModel>,
1173 cx: &mut Context<Self>,
1174 ) {
1175 let pending_completion_id = post_inc(&mut self.completion_count);
1176 let mut request_callback_parameters = if self.request_callback.is_some() {
1177 Some((request.clone(), Vec::new()))
1178 } else {
1179 None
1180 };
1181 let prompt_id = self.last_prompt_id.clone();
1182 let tool_use_metadata = ToolUseMetadata {
1183 model: model.clone(),
1184 thread_id: self.id.clone(),
1185 prompt_id: prompt_id.clone(),
1186 };
1187
1188 let task = cx.spawn(async move |thread, cx| {
1189 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1190 let initial_token_usage =
1191 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1192 let stream_completion = async {
1193 let (mut events, usage) = stream_completion_future.await?;
1194
1195 let mut stop_reason = StopReason::EndTurn;
1196 let mut current_token_usage = TokenUsage::default();
1197
1198 if let Some(usage) = usage {
1199 thread
1200 .update(cx, |_thread, cx| {
1201 cx.emit(ThreadEvent::UsageUpdated(usage));
1202 })
1203 .ok();
1204 }
1205
1206 while let Some(event) = events.next().await {
1207 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1208 response_events
1209 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1210 }
1211
1212 let event = event?;
1213
1214 thread.update(cx, |thread, cx| {
1215 match event {
1216 LanguageModelCompletionEvent::StartMessage { .. } => {
1217 thread.insert_message(
1218 Role::Assistant,
1219 vec![MessageSegment::Text(String::new())],
1220 cx,
1221 );
1222 }
1223 LanguageModelCompletionEvent::Stop(reason) => {
1224 stop_reason = reason;
1225 }
1226 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1227 thread.update_token_usage_at_last_message(token_usage);
1228 thread.cumulative_token_usage = thread.cumulative_token_usage
1229 + token_usage
1230 - current_token_usage;
1231 current_token_usage = token_usage;
1232 }
1233 LanguageModelCompletionEvent::Text(chunk) => {
1234 if let Some(last_message) = thread.messages.last_mut() {
1235 if last_message.role == Role::Assistant {
1236 last_message.push_text(&chunk);
1237 cx.emit(ThreadEvent::StreamedAssistantText(
1238 last_message.id,
1239 chunk,
1240 ));
1241 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1243 // of a new Assistant response.
1244 //
1245 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1246 // will result in duplicating the text of the chunk in the rendered Markdown.
1247 thread.insert_message(
1248 Role::Assistant,
1249 vec![MessageSegment::Text(chunk.to_string())],
1250 cx,
1251 );
1252 };
1253 }
1254 }
1255 LanguageModelCompletionEvent::Thinking {
1256 text: chunk,
1257 signature,
1258 } => {
1259 if let Some(last_message) = thread.messages.last_mut() {
1260 if last_message.role == Role::Assistant {
1261 last_message.push_thinking(&chunk, signature);
1262 cx.emit(ThreadEvent::StreamedAssistantThinking(
1263 last_message.id,
1264 chunk,
1265 ));
1266 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1268 // of a new Assistant response.
1269 //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
1271 // will result in duplicating the text of the chunk in the rendered Markdown.
1272 thread.insert_message(
1273 Role::Assistant,
1274 vec![MessageSegment::Thinking {
1275 text: chunk.to_string(),
1276 signature,
1277 }],
1278 cx,
1279 );
1280 };
1281 }
1282 }
1283 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1284 let last_assistant_message_id = thread
1285 .messages
1286 .iter_mut()
1287 .rfind(|message| message.role == Role::Assistant)
1288 .map(|message| message.id)
1289 .unwrap_or_else(|| {
1290 thread.insert_message(Role::Assistant, vec![], cx)
1291 });
1292
1293 thread.tool_use.request_tool_use(
1294 last_assistant_message_id,
1295 tool_use,
1296 tool_use_metadata.clone(),
1297 cx,
1298 );
1299 }
1300 }
1301
1302 thread.touch_updated_at();
1303 cx.emit(ThreadEvent::StreamedCompletion);
1304 cx.notify();
1305
1306 thread.auto_capture_telemetry(cx);
1307 })?;
1308
1309 smol::future::yield_now().await;
1310 }
1311
1312 thread.update(cx, |thread, cx| {
1313 thread
1314 .pending_completions
1315 .retain(|completion| completion.id != pending_completion_id);
1316
                    // If there is a response without tool use, summarize the thread. Otherwise,
                    // allow a couple of tool-use rounds before summarizing.
1319 if thread.summary.is_none()
1320 && thread.messages.len() >= 2
1321 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1322 {
1323 thread.summarize(cx);
1324 }
1325 })?;
1326
1327 anyhow::Ok(stop_reason)
1328 };
1329
1330 let result = stream_completion.await;
1331
1332 thread
1333 .update(cx, |thread, cx| {
1334 thread.finalize_pending_checkpoint(cx);
1335 match result.as_ref() {
1336 Ok(stop_reason) => match stop_reason {
1337 StopReason::ToolUse => {
1338 let tool_uses = thread.use_pending_tools(cx);
1339 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1340 }
1341 StopReason::EndTurn => {}
1342 StopReason::MaxTokens => {}
1343 },
1344 Err(error) => {
1345 if error.is::<PaymentRequiredError>() {
1346 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1347 } else if error.is::<MaxMonthlySpendReachedError>() {
1348 cx.emit(ThreadEvent::ShowError(
1349 ThreadError::MaxMonthlySpendReached,
1350 ));
1351 } else if let Some(error) =
1352 error.downcast_ref::<ModelRequestLimitReachedError>()
1353 {
1354 cx.emit(ThreadEvent::ShowError(
1355 ThreadError::ModelRequestLimitReached { plan: error.plan },
1356 ));
1357 } else if let Some(known_error) =
1358 error.downcast_ref::<LanguageModelKnownError>()
1359 {
1360 match known_error {
1361 LanguageModelKnownError::ContextWindowLimitExceeded {
1362 tokens,
1363 } => {
1364 thread.exceeded_window_error = Some(ExceededWindowError {
1365 model_id: model.id(),
1366 token_count: *tokens,
1367 });
1368 cx.notify();
1369 }
1370 }
1371 } else {
1372 let error_message = error
1373 .chain()
1374 .map(|err| err.to_string())
1375 .collect::<Vec<_>>()
1376 .join("\n");
1377 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1378 header: "Error interacting with language model".into(),
1379 message: SharedString::from(error_message.clone()),
1380 }));
1381 }
1382
1383 thread.cancel_last_completion(cx);
1384 }
1385 }
1386 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1387
1388 if let Some((request_callback, (request, response_events))) = thread
1389 .request_callback
1390 .as_mut()
1391 .zip(request_callback_parameters.as_ref())
1392 {
1393 request_callback(request, response_events);
1394 }
1395
1396 thread.auto_capture_telemetry(cx);
1397
1398 if let Ok(initial_usage) = initial_token_usage {
1399 let usage = thread.cumulative_token_usage - initial_usage;
1400
1401 telemetry::event!(
1402 "Assistant Thread Completion",
1403 thread_id = thread.id().to_string(),
1404 prompt_id = prompt_id,
1405 model = model.telemetry_id(),
1406 model_provider = model.provider_id().to_string(),
1407 input_tokens = usage.input_tokens,
1408 output_tokens = usage.output_tokens,
1409 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1410 cache_read_input_tokens = usage.cache_read_input_tokens,
1411 );
1412 }
1413 })
1414 .ok();
1415 });
1416
1417 self.pending_completions.push(PendingCompletion {
1418 id: pending_completion_id,
1419 _task: task,
1420 });
1421 }
1422
1423 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1424 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1425 return;
1426 };
1427
1428 if !model.provider.is_authenticated(cx) {
1429 return;
1430 }
1431
1432 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1433 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1434 If the conversation is about a specific subject, include it in the title. \
1435 Be descriptive. DO NOT speak in the first person.";
1436
1437 let request = self.to_summarize_request(added_user_message.into());
1438
1439 self.pending_summary = cx.spawn(async move |this, cx| {
1440 async move {
1441 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1442 let (mut messages, usage) = stream.await?;
1443
1444 if let Some(usage) = usage {
1445 this.update(cx, |_thread, cx| {
1446 cx.emit(ThreadEvent::UsageUpdated(usage));
1447 })
1448 .ok();
1449 }
1450
1451 let mut new_summary = String::new();
1452 while let Some(message) = messages.stream.next().await {
1453 let text = message?;
1454 let mut lines = text.lines();
1455 new_summary.extend(lines.next());
1456
1457 // Stop if the LLM generated multiple lines.
1458 if lines.next().is_some() {
1459 break;
1460 }
1461 }
1462
1463 this.update(cx, |this, cx| {
1464 if !new_summary.is_empty() {
1465 this.summary = Some(new_summary.into());
1466 }
1467
1468 cx.emit(ThreadEvent::SummaryGenerated);
1469 })?;
1470
1471 anyhow::Ok(())
1472 }
1473 .log_err()
1474 .await
1475 });
1476 }
1477
1478 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1479 let last_message_id = self.messages.last().map(|message| message.id)?;
1480
1481 match &self.detailed_summary_state {
1482 DetailedSummaryState::Generating { message_id, .. }
1483 | DetailedSummaryState::Generated { message_id, .. }
1484 if *message_id == last_message_id =>
1485 {
1486 // Already up-to-date
1487 return None;
1488 }
1489 _ => {}
1490 }
1491
1492 let ConfiguredModel { model, provider } =
1493 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1494
1495 if !provider.is_authenticated(cx) {
1496 return None;
1497 }
1498
1499 let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1500 1. A brief overview of what was discussed\n\
1501 2. Key facts or information discovered\n\
1502 3. Outcomes or conclusions reached\n\
1503 4. Any action items or next steps if any\n\
1504 Format it in Markdown with headings and bullet points.";
1505
1506 let request = self.to_summarize_request(added_user_message.into());
1507
1508 let task = cx.spawn(async move |thread, cx| {
1509 let stream = model.stream_completion_text(request, &cx);
1510 let Some(mut messages) = stream.await.log_err() else {
1511 thread
1512 .update(cx, |this, _cx| {
1513 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1514 })
1515 .log_err();
1516
1517 return;
1518 };
1519
1520 let mut new_detailed_summary = String::new();
1521
1522 while let Some(chunk) = messages.stream.next().await {
1523 if let Some(chunk) = chunk.log_err() {
1524 new_detailed_summary.push_str(&chunk);
1525 }
1526 }
1527
1528 thread
1529 .update(cx, |this, _cx| {
1530 this.detailed_summary_state = DetailedSummaryState::Generated {
1531 text: new_detailed_summary.into(),
1532 message_id: last_message_id,
1533 };
1534 })
1535 .log_err();
1536 });
1537
1538 self.detailed_summary_state = DetailedSummaryState::Generating {
1539 message_id: last_message_id,
1540 };
1541
1542 Some(task)
1543 }
1544
1545 pub fn is_generating_detailed_summary(&self) -> bool {
1546 matches!(
1547 self.detailed_summary_state,
1548 DetailedSummaryState::Generating { .. }
1549 )
1550 }
1551
1552 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1553 self.auto_capture_telemetry(cx);
1554 let request = self.to_completion_request(cx);
1555 let messages = Arc::new(request.messages);
1556 let pending_tool_uses = self
1557 .tool_use
1558 .pending_tool_uses()
1559 .into_iter()
1560 .filter(|tool_use| tool_use.status.is_idle())
1561 .cloned()
1562 .collect::<Vec<_>>();
1563
1564 for tool_use in pending_tool_uses.iter() {
1565 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1566 if tool.needs_confirmation(&tool_use.input, cx)
1567 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1568 {
1569 self.tool_use.confirm_tool_use(
1570 tool_use.id.clone(),
1571 tool_use.ui_text.clone(),
1572 tool_use.input.clone(),
1573 messages.clone(),
1574 tool,
1575 );
1576 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1577 } else {
1578 self.run_tool(
1579 tool_use.id.clone(),
1580 tool_use.ui_text.clone(),
1581 tool_use.input.clone(),
1582 &messages,
1583 tool,
1584 cx,
1585 );
1586 }
1587 }
1588 }
1589
1590 pending_tool_uses
1591 }
1592
1593 pub fn run_tool(
1594 &mut self,
1595 tool_use_id: LanguageModelToolUseId,
1596 ui_text: impl Into<SharedString>,
1597 input: serde_json::Value,
1598 messages: &[LanguageModelRequestMessage],
1599 tool: Arc<dyn Tool>,
1600 cx: &mut Context<Thread>,
1601 ) {
1602 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1603 self.tool_use
1604 .run_pending_tool(tool_use_id, ui_text.into(), task);
1605 }
1606
1607 fn spawn_tool_use(
1608 &mut self,
1609 tool_use_id: LanguageModelToolUseId,
1610 messages: &[LanguageModelRequestMessage],
1611 input: serde_json::Value,
1612 tool: Arc<dyn Tool>,
1613 cx: &mut Context<Thread>,
1614 ) -> Task<()> {
1615 let tool_name: Arc<str> = tool.name().into();
1616
1617 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1618 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1619 } else {
1620 tool.run(
1621 input,
1622 messages,
1623 self.project.clone(),
1624 self.action_log.clone(),
1625 cx,
1626 )
1627 };
1628
1629 // Store the card separately if it exists
1630 if let Some(card) = tool_result.card.clone() {
1631 self.tool_use
1632 .insert_tool_result_card(tool_use_id.clone(), card);
1633 }
1634
1635 cx.spawn({
1636 async move |thread: WeakEntity<Thread>, cx| {
1637 let output = tool_result.output.await;
1638
1639 thread
1640 .update(cx, |thread, cx| {
1641 let pending_tool_use = thread.tool_use.insert_tool_output(
1642 tool_use_id.clone(),
1643 tool_name,
1644 output,
1645 cx,
1646 );
1647 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1648 })
1649 .ok();
1650 }
1651 })
1652 }
1653
1654 fn tool_finished(
1655 &mut self,
1656 tool_use_id: LanguageModelToolUseId,
1657 pending_tool_use: Option<PendingToolUse>,
1658 canceled: bool,
1659 cx: &mut Context<Self>,
1660 ) {
1661 if self.all_tools_finished() {
1662 let model_registry = LanguageModelRegistry::read_global(cx);
1663 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1664 self.attach_tool_results(cx);
1665 if !canceled {
1666 self.send_to_model(model, cx);
1667 }
1668 }
1669 }
1670
1671 cx.emit(ThreadEvent::ToolFinished {
1672 tool_use_id,
1673 pending_tool_use,
1674 });
1675 }
1676
1677 /// Insert an empty message to be populated with tool results upon send.
1678 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Tool results are assumed to be waiting on the next message ID, so they will populate
        // this empty message before it is sent to the model. Would prefer this to be more straightforward.
1681 self.insert_message(Role::User, vec![], cx);
1682 self.auto_capture_telemetry(cx);
1683 }
1684
1685 /// Cancels the last pending completion, if there are any pending.
1686 ///
1687 /// Returns whether a completion was canceled.
1688 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1689 let canceled = if self.pending_completions.pop().is_some() {
1690 true
1691 } else {
1692 let mut canceled = false;
1693 for pending_tool_use in self.tool_use.cancel_pending() {
1694 canceled = true;
1695 self.tool_finished(
1696 pending_tool_use.id.clone(),
1697 Some(pending_tool_use),
1698 true,
1699 cx,
1700 );
1701 }
1702 canceled
1703 };
1704 self.finalize_pending_checkpoint(cx);
1705 canceled
1706 }
1707
1708 pub fn feedback(&self) -> Option<ThreadFeedback> {
1709 self.feedback
1710 }
1711
1712 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1713 self.message_feedback.get(&message_id).copied()
1714 }
1715
1716 pub fn report_message_feedback(
1717 &mut self,
1718 message_id: MessageId,
1719 feedback: ThreadFeedback,
1720 cx: &mut Context<Self>,
1721 ) -> Task<Result<()>> {
1722 if self.message_feedback.get(&message_id) == Some(&feedback) {
1723 return Task::ready(Ok(()));
1724 }
1725
1726 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1727 let serialized_thread = self.serialize(cx);
1728 let thread_id = self.id().clone();
1729 let client = self.project.read(cx).client();
1730
1731 let enabled_tool_names: Vec<String> = self
1732 .tools()
1733 .read(cx)
1734 .enabled_tools(cx)
1735 .iter()
1736 .map(|tool| tool.name().to_string())
1737 .collect();
1738
1739 self.message_feedback.insert(message_id, feedback);
1740
1741 cx.notify();
1742
1743 let message_content = self
1744 .message(message_id)
1745 .map(|msg| msg.to_string())
1746 .unwrap_or_default();
1747
1748 cx.background_spawn(async move {
1749 let final_project_snapshot = final_project_snapshot.await;
1750 let serialized_thread = serialized_thread.await?;
1751 let thread_data =
1752 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1753
1754 let rating = match feedback {
1755 ThreadFeedback::Positive => "positive",
1756 ThreadFeedback::Negative => "negative",
1757 };
1758 telemetry::event!(
1759 "Assistant Thread Rated",
1760 rating,
1761 thread_id,
1762 enabled_tool_names,
1763 message_id = message_id.0,
1764 message_content,
1765 thread_data,
1766 final_project_snapshot
1767 );
1768 client.telemetry().flush_events();
1769
1770 Ok(())
1771 })
1772 }
1773
1774 pub fn report_feedback(
1775 &mut self,
1776 feedback: ThreadFeedback,
1777 cx: &mut Context<Self>,
1778 ) -> Task<Result<()>> {
1779 let last_assistant_message_id = self
1780 .messages
1781 .iter()
1782 .rev()
1783 .find(|msg| msg.role == Role::Assistant)
1784 .map(|msg| msg.id);
1785
1786 if let Some(message_id) = last_assistant_message_id {
1787 self.report_message_feedback(message_id, feedback, cx)
1788 } else {
1789 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1790 let serialized_thread = self.serialize(cx);
1791 let thread_id = self.id().clone();
1792 let client = self.project.read(cx).client();
1793 self.feedback = Some(feedback);
1794 cx.notify();
1795
1796 cx.background_spawn(async move {
1797 let final_project_snapshot = final_project_snapshot.await;
1798 let serialized_thread = serialized_thread.await?;
1799 let thread_data = serde_json::to_value(serialized_thread)
1800 .unwrap_or_else(|_| serde_json::Value::Null);
1801
1802 let rating = match feedback {
1803 ThreadFeedback::Positive => "positive",
1804 ThreadFeedback::Negative => "negative",
1805 };
1806 telemetry::event!(
1807 "Assistant Thread Rated",
1808 rating,
1809 thread_id,
1810 thread_data,
1811 final_project_snapshot
1812 );
1813 client.telemetry().flush_events();
1814
1815 Ok(())
1816 })
1817 }
1818 }
1819
1820 /// Create a snapshot of the current project state including git information and unsaved buffers.
1821 fn project_snapshot(
1822 project: Entity<Project>,
1823 cx: &mut Context<Self>,
1824 ) -> Task<Arc<ProjectSnapshot>> {
1825 let git_store = project.read(cx).git_store().clone();
1826 let worktree_snapshots: Vec<_> = project
1827 .read(cx)
1828 .visible_worktrees(cx)
1829 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1830 .collect();
1831
1832 cx.spawn(async move |_, cx| {
1833 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1834
1835 let mut unsaved_buffers = Vec::new();
1836 cx.update(|app_cx| {
1837 let buffer_store = project.read(app_cx).buffer_store();
1838 for buffer_handle in buffer_store.read(app_cx).buffers() {
1839 let buffer = buffer_handle.read(app_cx);
1840 if buffer.is_dirty() {
1841 if let Some(file) = buffer.file() {
1842 let path = file.path().to_string_lossy().to_string();
1843 unsaved_buffers.push(path);
1844 }
1845 }
1846 }
1847 })
1848 .ok();
1849
1850 Arc::new(ProjectSnapshot {
1851 worktree_snapshots,
1852 unsaved_buffer_paths: unsaved_buffers,
1853 timestamp: Utc::now(),
1854 })
1855 })
1856 }
1857
1858 fn worktree_snapshot(
1859 worktree: Entity<project::Worktree>,
1860 git_store: Entity<GitStore>,
1861 cx: &App,
1862 ) -> Task<WorktreeSnapshot> {
1863 cx.spawn(async move |cx| {
1864 // Get worktree path and snapshot
1865 let worktree_info = cx.update(|app_cx| {
1866 let worktree = worktree.read(app_cx);
1867 let path = worktree.abs_path().to_string_lossy().to_string();
1868 let snapshot = worktree.snapshot();
1869 (path, snapshot)
1870 });
1871
1872 let Ok((worktree_path, _snapshot)) = worktree_info else {
1873 return WorktreeSnapshot {
1874 worktree_path: String::new(),
1875 git_state: None,
1876 };
1877 };
1878
1879 let git_state = git_store
1880 .update(cx, |git_store, cx| {
1881 git_store
1882 .repositories()
1883 .values()
1884 .find(|repo| {
1885 repo.read(cx)
1886 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1887 .is_some()
1888 })
1889 .cloned()
1890 })
1891 .ok()
1892 .flatten()
1893 .map(|repo| {
1894 repo.update(cx, |repo, _| {
1895 let current_branch =
1896 repo.branch.as_ref().map(|branch| branch.name.to_string());
1897 repo.send_job(None, |state, _| async move {
1898 let RepositoryState::Local { backend, .. } = state else {
1899 return GitState {
1900 remote_url: None,
1901 head_sha: None,
1902 current_branch,
1903 diff: None,
1904 };
1905 };
1906
1907 let remote_url = backend.remote_url("origin");
1908 let head_sha = backend.head_sha();
1909 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1910
1911 GitState {
1912 remote_url,
1913 head_sha,
1914 current_branch,
1915 diff,
1916 }
1917 })
1918 })
1919 });
1920
1921 let git_state = match git_state {
1922 Some(git_state) => match git_state.ok() {
1923 Some(git_state) => git_state.await.ok(),
1924 None => None,
1925 },
1926 None => None,
1927 };
1928
1929 WorktreeSnapshot {
1930 worktree_path,
1931 git_state,
1932 }
1933 })
1934 }
1935
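    /// Renders the thread as Markdown: a `#` heading for the summary, a `##` heading per
    /// message role, the message text, and any tool uses/results for each message.
    ///
    /// Output sketch (illustrative):
    ///
    /// ```text
    /// # Fix flaky test
    ///
    /// ## User
    ///
    /// Why does this test fail intermittently?
    ///
    /// ## Assistant
    ///
    /// It looks like the test depends on wall-clock time.
    /// ```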
1936 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1937 let mut markdown = Vec::new();
1938
1939 if let Some(summary) = self.summary() {
1940 writeln!(markdown, "# {summary}\n")?;
1941 };
1942
1943 for message in self.messages() {
1944 writeln!(
1945 markdown,
1946 "## {role}\n",
1947 role = match message.role {
1948 Role::User => "User",
1949 Role::Assistant => "Assistant",
1950 Role::System => "System",
1951 }
1952 )?;
1953
1954 if !message.context.is_empty() {
1955 writeln!(markdown, "{}", message.context)?;
1956 }
1957
1958 for segment in &message.segments {
1959 match segment {
1960 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1961 MessageSegment::Thinking { text, .. } => {
1962 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
1963 }
1964 MessageSegment::RedactedThinking(_) => {}
1965 }
1966 }
1967
1968 for tool_use in self.tool_uses_for_message(message.id, cx) {
1969 writeln!(
1970 markdown,
1971 "**Use Tool: {} ({})**",
1972 tool_use.name, tool_use.id
1973 )?;
1974 writeln!(markdown, "```json")?;
1975 writeln!(
1976 markdown,
1977 "{}",
1978 serde_json::to_string_pretty(&tool_use.input)?
1979 )?;
1980 writeln!(markdown, "```")?;
1981 }
1982
1983 for tool_result in self.tool_results_for_message(message.id) {
1984 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1985 if tool_result.is_error {
1986 write!(markdown, " (Error)")?;
1987 }
1988
1989 writeln!(markdown, "**\n")?;
1990 writeln!(markdown, "{}", tool_result.content)?;
1991 }
1992 }
1993
1994 Ok(String::from_utf8_lossy(&markdown).to_string())
1995 }
1996
1997 pub fn keep_edits_in_range(
1998 &mut self,
1999 buffer: Entity<language::Buffer>,
2000 buffer_range: Range<language::Anchor>,
2001 cx: &mut Context<Self>,
2002 ) {
2003 self.action_log.update(cx, |action_log, cx| {
2004 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2005 });
2006 }
2007
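    /// Marks all of the agent's edits as kept in the action log.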
2008 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2009 self.action_log
2010 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2011 }
2012
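    /// Asks the action log to reject the agent's edits within the given buffer ranges.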
2013 pub fn reject_edits_in_ranges(
2014 &mut self,
2015 buffer: Entity<language::Buffer>,
2016 buffer_ranges: Vec<Range<language::Anchor>>,
2017 cx: &mut Context<Self>,
2018 ) -> Task<Result<()>> {
2019 self.action_log.update(cx, |action_log, cx| {
2020 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2021 })
2022 }
2023
2024 pub fn action_log(&self) -> &Entity<ActionLog> {
2025 &self.action_log
2026 }
2027
2028 pub fn project(&self) -> &Entity<Project> {
2029 &self.project
2030 }
2031
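    /// Captures a serialized copy of the thread as a telemetry event when the
    /// `ThreadAutoCapture` feature flag is enabled, at most once every 10 seconds.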
2032 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2033 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2034 return;
2035 }
2036
2037 let now = Instant::now();
2038 if let Some(last) = self.last_auto_capture_at {
2039 if now.duration_since(last).as_secs() < 10 {
2040 return;
2041 }
2042 }
2043
2044 self.last_auto_capture_at = Some(now);
2045
2046 let thread_id = self.id().clone();
2047 let github_login = self
2048 .project
2049 .read(cx)
2050 .user_store()
2051 .read(cx)
2052 .current_user()
2053 .map(|user| user.github_login.clone());
2054 let client = self.project.read(cx).client().clone();
2055 let serialize_task = self.serialize(cx);
2056
2057 cx.background_executor()
2058 .spawn(async move {
2059 if let Ok(serialized_thread) = serialize_task.await {
2060 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2061 telemetry::event!(
2062 "Agent Thread Auto-Captured",
2063 thread_id = thread_id.to_string(),
2064 thread_data = thread_data,
2065 auto_capture_reason = "tracked_user",
2066 github_login = github_login
2067 );
2068
2069 client.telemetry().flush_events();
2070 }
2071 }
2072 })
2073 .detach();
2074 }
2075
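    /// Returns the cumulative token usage across all requests made by this thread.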
2076 pub fn cumulative_token_usage(&self) -> TokenUsage {
2077 self.cumulative_token_usage
2078 }
2079
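    /// Returns the token usage recorded just before the given message, along with the
    /// default model's maximum token count.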
2080 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2081 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2082 return TotalTokenUsage::default();
2083 };
2084
2085 let max = model.model.max_token_count();
2086
2087 let index = self
2088 .messages
2089 .iter()
2090 .position(|msg| msg.id == message_id)
2091 .unwrap_or(0);
2092
2093 if index == 0 {
2094 return TotalTokenUsage { total: 0, max };
2095 }
2096
        let token_usage = self
            .request_token_usage
            .get(index - 1)
            .cloned()
            .unwrap_or_default();
2102
2103 TotalTokenUsage {
2104 total: token_usage.total_tokens() as usize,
2105 max,
2106 }
2107 }
2108
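    /// Returns the thread's total token usage and the default model's maximum token
    /// count, reporting the exceeded count if this model's context window was exceeded.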
2109 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2110 let model_registry = LanguageModelRegistry::read_global(cx);
2111 let Some(model) = model_registry.default_model() else {
2112 return TotalTokenUsage::default();
2113 };
2114
2115 let max = model.model.max_token_count();
2116
2117 if let Some(exceeded_error) = &self.exceeded_window_error {
2118 if model.model.id() == exceeded_error.model_id {
2119 return TotalTokenUsage {
2120 total: exceeded_error.token_count,
2121 max,
2122 };
2123 }
2124 }
2125
2126 let total = self
2127 .token_usage_at_last_message()
2128 .unwrap_or_default()
2129 .total_tokens() as usize;
2130
2131 TotalTokenUsage { total, max }
2132 }
2133
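    /// Returns the token usage recorded at the last message, falling back to the most recent entry.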
2134 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2135 self.request_token_usage
2136 .get(self.messages.len().saturating_sub(1))
2137 .or_else(|| self.request_token_usage.last())
2138 .cloned()
2139 }
2140
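    /// Records `token_usage` for the last message, growing the bookkeeping vector to
    /// match the number of messages if needed.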
2141 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2142 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2143 self.request_token_usage
2144 .resize(self.messages.len(), placeholder);
2145
2146 if let Some(last) = self.request_token_usage.last_mut() {
2147 *last = token_usage;
2148 }
2149 }
2150
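    /// Reports the user's denial of a tool use as an error result and marks the tool as finished.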
2151 pub fn deny_tool_use(
2152 &mut self,
2153 tool_use_id: LanguageModelToolUseId,
2154 tool_name: Arc<str>,
2155 cx: &mut Context<Self>,
2156 ) {
2157 let err = Err(anyhow::anyhow!(
2158 "Permission to run tool action denied by user"
2159 ));
2160
2161 self.tool_use
2162 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2163 self.tool_finished(tool_use_id.clone(), None, true, cx);
2164 }
2165}
2166
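/// An error that can be surfaced to the user while running a [`Thread`].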
2167#[derive(Debug, Clone, Error)]
2168pub enum ThreadError {
2169 #[error("Payment required")]
2170 PaymentRequired,
2171 #[error("Max monthly spend reached")]
2172 MaxMonthlySpendReached,
2173 #[error("Model request limit reached")]
2174 ModelRequestLimitReached { plan: Plan },
2175 #[error("Message {header}: {message}")]
2176 Message {
2177 header: SharedString,
2178 message: SharedString,
2179 },
2180}
2181
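/// An event emitted by a [`Thread`].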
2182#[derive(Debug, Clone)]
2183pub enum ThreadEvent {
2184 ShowError(ThreadError),
2185 UsageUpdated(RequestUsage),
2186 StreamedCompletion,
2187 StreamedAssistantText(MessageId, String),
2188 StreamedAssistantThinking(MessageId, String),
2189 Stopped(Result<StopReason, Arc<anyhow::Error>>),
2190 MessageAdded(MessageId),
2191 MessageEdited(MessageId),
2192 MessageDeleted(MessageId),
2193 SummaryGenerated,
2194 SummaryChanged,
2195 UsePendingTools {
2196 tool_uses: Vec<PendingToolUse>,
2197 },
2198 ToolFinished {
2199 #[allow(unused)]
2200 tool_use_id: LanguageModelToolUseId,
2201 /// The pending tool use that corresponds to this tool.
2202 pending_tool_use: Option<PendingToolUse>,
2203 },
2204 CheckpointChanged,
2205 ToolConfirmationNeeded,
2206}
2207
2208impl EventEmitter<ThreadEvent> for Thread {}
2209
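/// An in-flight completion request, identified by `id` and kept alive by holding its task.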
2210struct PendingCompletion {
2211 id: usize,
2212 _task: Task<()>,
2213}
2214
2215#[cfg(test)]
2216mod tests {
2217 use super::*;
2218 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2219 use assistant_settings::AssistantSettings;
2220 use context_server::ContextServerSettings;
2221 use editor::EditorSettings;
2222 use gpui::TestAppContext;
2223 use project::{FakeFs, Project};
2224 use prompt_store::PromptBuilder;
2225 use serde_json::json;
2226 use settings::{Settings, SettingsStore};
2227 use std::sync::Arc;
2228 use theme::ThemeSettings;
2229 use util::path;
2230 use workspace::Workspace;
2231
2232 #[gpui::test]
2233 async fn test_message_with_context(cx: &mut TestAppContext) {
2234 init_test_settings(cx);
2235
2236 let project = create_test_project(
2237 cx,
2238 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2239 )
2240 .await;
2241
2242 let (_workspace, _thread_store, thread, context_store) =
2243 setup_test_environment(cx, project.clone()).await;
2244
2245 add_file_to_context(&project, &context_store, "test/code.rs", cx)
2246 .await
2247 .unwrap();
2248
2249 let context =
2250 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2251
2252 // Insert user message with context
2253 let message_id = thread.update(cx, |thread, cx| {
2254 thread.insert_user_message("Please explain this code", vec![context], None, cx)
2255 });
2256
2257 // Check content and context in message object
2258 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2259
        // Path separators differ by platform, so build the expected path accordingly
2261 #[cfg(windows)]
2262 let path_part = r"test\code.rs";
2263 #[cfg(not(windows))]
2264 let path_part = "test/code.rs";
2265
2266 let expected_context = format!(
2267 r#"
2268<context>
2269The following items were attached by the user. You don't need to use other tools to read them.
2270
2271<files>
2272```rs {path_part}
2273fn main() {{
2274 println!("Hello, world!");
2275}}
2276```
2277</files>
2278</context>
2279"#
2280 );
2281
2282 assert_eq!(message.role, Role::User);
2283 assert_eq!(message.segments.len(), 1);
2284 assert_eq!(
2285 message.segments[0],
2286 MessageSegment::Text("Please explain this code".to_string())
2287 );
2288 assert_eq!(message.context, expected_context);
2289
2290 // Check message in request
2291 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2292
2293 assert_eq!(request.messages.len(), 2);
2294 let expected_full_message = format!("{}Please explain this code", expected_context);
2295 assert_eq!(request.messages[1].string_contents(), expected_full_message);
2296 }
2297
2298 #[gpui::test]
2299 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2300 init_test_settings(cx);
2301
2302 let project = create_test_project(
2303 cx,
2304 json!({
2305 "file1.rs": "fn function1() {}\n",
2306 "file2.rs": "fn function2() {}\n",
2307 "file3.rs": "fn function3() {}\n",
2308 }),
2309 )
2310 .await;
2311
2312 let (_, _thread_store, thread, context_store) =
2313 setup_test_environment(cx, project.clone()).await;
2314
2315 // Open files individually
2316 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2317 .await
2318 .unwrap();
2319 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2320 .await
2321 .unwrap();
2322 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2323 .await
2324 .unwrap();
2325
2326 // Get the context objects
2327 let contexts = context_store.update(cx, |store, _| store.context().clone());
2328 assert_eq!(contexts.len(), 3);
2329
2330 // First message with context 1
2331 let message1_id = thread.update(cx, |thread, cx| {
2332 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2333 });
2334
2335 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2336 let message2_id = thread.update(cx, |thread, cx| {
2337 thread.insert_user_message(
2338 "Message 2",
2339 vec![contexts[0].clone(), contexts[1].clone()],
2340 None,
2341 cx,
2342 )
2343 });
2344
2345 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2346 let message3_id = thread.update(cx, |thread, cx| {
2347 thread.insert_user_message(
2348 "Message 3",
2349 vec![
2350 contexts[0].clone(),
2351 contexts[1].clone(),
2352 contexts[2].clone(),
2353 ],
2354 None,
2355 cx,
2356 )
2357 });
2358
2359 // Check what contexts are included in each message
2360 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2361 (
2362 thread.message(message1_id).unwrap().clone(),
2363 thread.message(message2_id).unwrap().clone(),
2364 thread.message(message3_id).unwrap().clone(),
2365 )
2366 });
2367
2368 // First message should include context 1
2369 assert!(message1.context.contains("file1.rs"));
2370
2371 // Second message should include only context 2 (not 1)
2372 assert!(!message2.context.contains("file1.rs"));
2373 assert!(message2.context.contains("file2.rs"));
2374
2375 // Third message should include only context 3 (not 1 or 2)
2376 assert!(!message3.context.contains("file1.rs"));
2377 assert!(!message3.context.contains("file2.rs"));
2378 assert!(message3.context.contains("file3.rs"));
2379
2380 // Check entire request to make sure all contexts are properly included
2381 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2382
        // The request should contain the system prompt plus all 3 user messages
2384 assert_eq!(request.messages.len(), 4);
2385
2386 // Check that the contexts are properly formatted in each message
2387 assert!(request.messages[1].string_contents().contains("file1.rs"));
2388 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2389 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2390
2391 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2392 assert!(request.messages[2].string_contents().contains("file2.rs"));
2393 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2394
2395 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2396 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2397 assert!(request.messages[3].string_contents().contains("file3.rs"));
2398 }
2399
2400 #[gpui::test]
2401 async fn test_message_without_files(cx: &mut TestAppContext) {
2402 init_test_settings(cx);
2403
2404 let project = create_test_project(
2405 cx,
2406 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2407 )
2408 .await;
2409
2410 let (_, _thread_store, thread, _context_store) =
2411 setup_test_environment(cx, project.clone()).await;
2412
2413 // Insert user message without any context (empty context vector)
2414 let message_id = thread.update(cx, |thread, cx| {
2415 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2416 });
2417
2418 // Check content and context in message object
2419 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2420
2421 // Context should be empty when no files are included
2422 assert_eq!(message.role, Role::User);
2423 assert_eq!(message.segments.len(), 1);
2424 assert_eq!(
2425 message.segments[0],
2426 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2427 );
2428 assert_eq!(message.context, "");
2429
2430 // Check message in request
2431 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2432
2433 assert_eq!(request.messages.len(), 2);
2434 assert_eq!(
2435 request.messages[1].string_contents(),
2436 "What is the best way to learn Rust?"
2437 );
2438
2439 // Add second message, also without context
2440 let message2_id = thread.update(cx, |thread, cx| {
2441 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2442 });
2443
2444 let message2 =
2445 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2446 assert_eq!(message2.context, "");
2447
2448 // Check that both messages appear in the request
2449 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2450
2451 assert_eq!(request.messages.len(), 3);
2452 assert_eq!(
2453 request.messages[1].string_contents(),
2454 "What is the best way to learn Rust?"
2455 );
2456 assert_eq!(
2457 request.messages[2].string_contents(),
2458 "Are there any good books?"
2459 );
2460 }
2461
2462 #[gpui::test]
2463 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2464 init_test_settings(cx);
2465
2466 let project = create_test_project(
2467 cx,
2468 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2469 )
2470 .await;
2471
2472 let (_workspace, _thread_store, thread, context_store) =
2473 setup_test_environment(cx, project.clone()).await;
2474
2475 // Open buffer and add it to context
2476 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2477 .await
2478 .unwrap();
2479
2480 let context =
2481 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2482
2483 // Insert user message with the buffer as context
2484 thread.update(cx, |thread, cx| {
2485 thread.insert_user_message("Explain this code", vec![context], None, cx)
2486 });
2487
2488 // Create a request and check that it doesn't have a stale buffer warning yet
2489 let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2490
2491 // Make sure we don't have a stale file warning yet
2492 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2493 msg.string_contents()
2494 .contains("These files changed since last read:")
2495 });
2496 assert!(
2497 !has_stale_warning,
2498 "Should not have stale buffer warning before buffer is modified"
2499 );
2500
2501 // Modify the buffer
2502 buffer.update(cx, |buffer, cx| {
            // Insert text near the start of the buffer so it no longer matches the attached context
2504 buffer.edit(
2505 [(1..1, "\n println!(\"Added a new line\");\n")],
2506 None,
2507 cx,
2508 );
2509 });
2510
2511 // Insert another user message without context
2512 thread.update(cx, |thread, cx| {
2513 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2514 });
2515
2516 // Create a new request and check for the stale buffer warning
2517 let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2518
2519 // We should have a stale file warning as the last message
2520 let last_message = new_request
2521 .messages
2522 .last()
2523 .expect("Request should have messages");
2524
2525 // The last message should be the stale buffer notification
2526 assert_eq!(last_message.role, Role::User);
2527
2528 // Check the exact content of the message
2529 let expected_content = "These files changed since last read:\n- code.rs\n";
2530 assert_eq!(
2531 last_message.string_contents(),
2532 expected_content,
2533 "Last message should be exactly the stale buffer notification"
2534 );
2535 }
2536
2537 fn init_test_settings(cx: &mut TestAppContext) {
2538 cx.update(|cx| {
2539 let settings_store = SettingsStore::test(cx);
2540 cx.set_global(settings_store);
2541 language::init(cx);
2542 Project::init_settings(cx);
2543 AssistantSettings::register(cx);
2544 prompt_store::init(cx);
2545 thread_store::init(cx);
2546 workspace::init_settings(cx);
2547 ThemeSettings::register(cx);
2548 ContextServerSettings::register(cx);
2549 EditorSettings::register(cx);
2550 });
2551 }
2552
2553 // Helper to create a test project with test files
2554 async fn create_test_project(
2555 cx: &mut TestAppContext,
2556 files: serde_json::Value,
2557 ) -> Entity<Project> {
2558 let fs = FakeFs::new(cx.executor());
2559 fs.insert_tree(path!("/test"), files).await;
2560 Project::test(fs, [path!("/test").as_ref()], cx).await
2561 }
2562
2563 async fn setup_test_environment(
2564 cx: &mut TestAppContext,
2565 project: Entity<Project>,
2566 ) -> (
2567 Entity<Workspace>,
2568 Entity<ThreadStore>,
2569 Entity<Thread>,
2570 Entity<ContextStore>,
2571 ) {
2572 let (workspace, cx) =
2573 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2574
2575 let thread_store = cx
2576 .update(|_, cx| {
2577 ThreadStore::load(
2578 project.clone(),
2579 cx.new(|_| ToolWorkingSet::default()),
2580 Arc::new(PromptBuilder::new(None).unwrap()),
2581 cx,
2582 )
2583 })
2584 .await
2585 .unwrap();
2586
2587 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2588 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2589
2590 (workspace, thread_store, thread, context_store)
2591 }
2592
2593 async fn add_file_to_context(
2594 project: &Entity<Project>,
2595 context_store: &Entity<ContextStore>,
2596 path: &str,
2597 cx: &mut TestAppContext,
2598 ) -> Result<Entity<language::Buffer>> {
2599 let buffer_path = project
2600 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2601 .unwrap();
2602
2603 let buffer = project
2604 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2605 .await
2606 .unwrap();
2607
2608 context_store
2609 .update(cx, |store, cx| {
2610 store.add_file_from_buffer(buffer.clone(), cx)
2611 })
2612 .await?;
2613
2614 Ok(buffer)
2615 }
2616}