1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
42
43#[derive(
44 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
45)]
46pub struct ThreadId(Arc<str>);
47
48impl ThreadId {
49 pub fn new() -> Self {
50 Self(Uuid::new_v4().to_string().into())
51 }
52}
53
54impl std::fmt::Display for ThreadId {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
60impl From<&str> for ThreadId {
61 fn from(value: &str) -> Self {
62 Self(value.into())
63 }
64}
65
66/// The ID of the user prompt that initiated a request.
67///
68/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
70pub struct PromptId(Arc<str>);
71
72impl PromptId {
73 pub fn new() -> Self {
74 Self(Uuid::new_v4().to_string().into())
75 }
76}
77
78impl std::fmt::Display for PromptId {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "{}", self.0)
81 }
82}
83
84#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
85pub struct MessageId(pub(crate) usize);
86
87impl MessageId {
88 fn post_inc(&mut self) -> Self {
89 Self(post_inc(&mut self.0))
90 }
91}
92
93/// A message in a [`Thread`].
94#[derive(Debug, Clone)]
95pub struct Message {
96 pub id: MessageId,
97 pub role: Role,
98 pub segments: Vec<MessageSegment>,
99 pub context: String,
100}
101
102impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    ///
    /// The model sometimes runs a tool without producing any text, or produces only a marker ([`USING_TOOL_MARKER`]).
105 pub fn should_display_content(&self) -> bool {
106 self.segments.iter().all(|segment| segment.should_display())
107 }
108
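    /// Appends thinking text to the last segment if it is already a `Thinking` segment,
    /// replacing its signature when a new one is provided; otherwise starts a new
    /// `Thinking` segment.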
109 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
110 if let Some(MessageSegment::Thinking {
111 text: segment,
112 signature: current_signature,
113 }) = self.segments.last_mut()
114 {
115 if let Some(signature) = signature {
116 *current_signature = Some(signature);
117 }
118 segment.push_str(text);
119 } else {
120 self.segments.push(MessageSegment::Thinking {
121 text: text.to_string(),
122 signature,
123 });
124 }
125 }
126
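    /// Appends text to the last segment if it is a `Text` segment; otherwise starts a new one.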
127 pub fn push_text(&mut self, text: &str) {
128 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
129 segment.push_str(text);
130 } else {
131 self.segments.push(MessageSegment::Text(text.to_string()));
132 }
133 }
134
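    /// Renders the message as plain text, prepending any attached context and wrapping
    /// thinking segments in `<think>` tags. Redacted thinking is omitted.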
135 pub fn to_string(&self) -> String {
136 let mut result = String::new();
137
138 if !self.context.is_empty() {
139 result.push_str(&self.context);
140 }
141
142 for segment in &self.segments {
143 match segment {
144 MessageSegment::Text(text) => result.push_str(text),
145 MessageSegment::Thinking { text, .. } => {
146 result.push_str("<think>\n");
147 result.push_str(text);
148 result.push_str("\n</think>");
149 }
150 MessageSegment::RedactedThinking(_) => {}
151 }
152 }
153
154 result
155 }
156}
157
158#[derive(Debug, Clone, PartialEq, Eq)]
159pub enum MessageSegment {
160 Text(String),
161 Thinking {
162 text: String,
163 signature: Option<String>,
164 },
165 RedactedThinking(Vec<u8>),
166}
167
168impl MessageSegment {
169 pub fn should_display(&self) -> bool {
170 match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
173 Self::RedactedThinking(_) => false,
174 }
175 }
176}
177
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct ProjectSnapshot {
180 pub worktree_snapshots: Vec<WorktreeSnapshot>,
181 pub unsaved_buffer_paths: Vec<String>,
182 pub timestamp: DateTime<Utc>,
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct WorktreeSnapshot {
187 pub worktree_path: String,
188 pub git_state: Option<GitState>,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct GitState {
193 pub remote_url: Option<String>,
194 pub head_sha: Option<String>,
195 pub current_branch: Option<String>,
196 pub diff: Option<String>,
197}
198
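/// A git checkpoint captured when a user message was sent, so the project can later be
/// restored to the state it was in at that point in the conversation.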
199#[derive(Clone)]
200pub struct ThreadCheckpoint {
201 message_id: MessageId,
202 git_checkpoint: GitStoreCheckpoint,
203}
204
205#[derive(Copy, Clone, Debug, PartialEq, Eq)]
206pub enum ThreadFeedback {
207 Positive,
208 Negative,
209}
210
211pub enum LastRestoreCheckpoint {
212 Pending {
213 message_id: MessageId,
214 },
215 Error {
216 message_id: MessageId,
217 error: String,
218 },
219}
220
221impl LastRestoreCheckpoint {
222 pub fn message_id(&self) -> MessageId {
223 match self {
224 LastRestoreCheckpoint::Pending { message_id } => *message_id,
225 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
226 }
227 }
228}
229
230#[derive(Clone, Debug, Default, Serialize, Deserialize)]
231pub enum DetailedSummaryState {
232 #[default]
233 NotGenerated,
234 Generating {
235 message_id: MessageId,
236 },
237 Generated {
238 text: SharedString,
239 message_id: MessageId,
240 },
241}
242
243#[derive(Default)]
244pub struct TotalTokenUsage {
245 pub total: usize,
246 pub max: usize,
247}
248
249impl TotalTokenUsage {
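    /// Classifies the current usage against the model's context window: `Exceeded` at or
    /// above the maximum, `Warning` at or above the warning threshold (0.8, overridable in
    /// debug builds via `ZED_THREAD_WARNING_THRESHOLD`), and `Normal` otherwise.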
250 pub fn ratio(&self) -> TokenUsageRatio {
251 #[cfg(debug_assertions)]
252 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
253 .unwrap_or("0.8".to_string())
254 .parse()
255 .unwrap();
256 #[cfg(not(debug_assertions))]
257 let warning_threshold: f32 = 0.8;
258
259 if self.total >= self.max {
260 TokenUsageRatio::Exceeded
261 } else if self.total as f32 / self.max as f32 >= warning_threshold {
262 TokenUsageRatio::Warning
263 } else {
264 TokenUsageRatio::Normal
265 }
266 }
267
268 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
269 TotalTokenUsage {
270 total: self.total + tokens,
271 max: self.max,
272 }
273 }
274}
275
276#[derive(Debug, Default, PartialEq, Eq)]
277pub enum TokenUsageRatio {
278 #[default]
279 Normal,
280 Warning,
281 Exceeded,
282}
283
284/// A thread of conversation with the LLM.
285pub struct Thread {
286 id: ThreadId,
287 updated_at: DateTime<Utc>,
288 summary: Option<SharedString>,
289 pending_summary: Task<Option<()>>,
290 detailed_summary_state: DetailedSummaryState,
291 messages: Vec<Message>,
292 next_message_id: MessageId,
293 last_prompt_id: PromptId,
294 context: BTreeMap<ContextId, AssistantContext>,
295 context_by_message: HashMap<MessageId, Vec<ContextId>>,
296 project_context: SharedProjectContext,
297 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
298 completion_count: usize,
299 pending_completions: Vec<PendingCompletion>,
300 project: Entity<Project>,
301 prompt_builder: Arc<PromptBuilder>,
302 tools: Entity<ToolWorkingSet>,
303 tool_use: ToolUseState,
304 action_log: Entity<ActionLog>,
305 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
306 pending_checkpoint: Option<ThreadCheckpoint>,
307 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
308 request_token_usage: Vec<TokenUsage>,
309 cumulative_token_usage: TokenUsage,
310 exceeded_window_error: Option<ExceededWindowError>,
311 feedback: Option<ThreadFeedback>,
312 message_feedback: HashMap<MessageId, ThreadFeedback>,
313 last_auto_capture_at: Option<Instant>,
314 request_callback: Option<
315 Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
316 >,
317}
318
319#[derive(Debug, Clone, Serialize, Deserialize)]
320pub struct ExceededWindowError {
321 /// Model used when last message exceeded context window
322 model_id: LanguageModelId,
323 /// Token count including last message
324 token_count: usize,
325}
326
327impl Thread {
328 pub fn new(
329 project: Entity<Project>,
330 tools: Entity<ToolWorkingSet>,
331 prompt_builder: Arc<PromptBuilder>,
332 system_prompt: SharedProjectContext,
333 cx: &mut Context<Self>,
334 ) -> Self {
335 Self {
336 id: ThreadId::new(),
337 updated_at: Utc::now(),
338 summary: None,
339 pending_summary: Task::ready(None),
340 detailed_summary_state: DetailedSummaryState::NotGenerated,
341 messages: Vec::new(),
342 next_message_id: MessageId(0),
343 last_prompt_id: PromptId::new(),
344 context: BTreeMap::default(),
345 context_by_message: HashMap::default(),
346 project_context: system_prompt,
347 checkpoints_by_message: HashMap::default(),
348 completion_count: 0,
349 pending_completions: Vec::new(),
350 project: project.clone(),
351 prompt_builder,
352 tools: tools.clone(),
353 last_restore_checkpoint: None,
354 pending_checkpoint: None,
355 tool_use: ToolUseState::new(tools.clone()),
356 action_log: cx.new(|_| ActionLog::new(project.clone())),
357 initial_project_snapshot: {
358 let project_snapshot = Self::project_snapshot(project, cx);
359 cx.foreground_executor()
360 .spawn(async move { Some(project_snapshot.await) })
361 .shared()
362 },
363 request_token_usage: Vec::new(),
364 cumulative_token_usage: TokenUsage::default(),
365 exceeded_window_error: None,
366 feedback: None,
367 message_feedback: HashMap::default(),
368 last_auto_capture_at: None,
369 request_callback: None,
370 }
371 }
372
373 pub fn deserialize(
374 id: ThreadId,
375 serialized: SerializedThread,
376 project: Entity<Project>,
377 tools: Entity<ToolWorkingSet>,
378 prompt_builder: Arc<PromptBuilder>,
379 project_context: SharedProjectContext,
380 cx: &mut Context<Self>,
381 ) -> Self {
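        // Message IDs are sequential, so the next ID continues after the last serialized message.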
382 let next_message_id = MessageId(
383 serialized
384 .messages
385 .last()
386 .map(|message| message.id.0 + 1)
387 .unwrap_or(0),
388 );
389 let tool_use =
390 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
391
392 Self {
393 id,
394 updated_at: serialized.updated_at,
395 summary: Some(serialized.summary),
396 pending_summary: Task::ready(None),
397 detailed_summary_state: serialized.detailed_summary_state,
398 messages: serialized
399 .messages
400 .into_iter()
401 .map(|message| Message {
402 id: message.id,
403 role: message.role,
404 segments: message
405 .segments
406 .into_iter()
407 .map(|segment| match segment {
408 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
409 SerializedMessageSegment::Thinking { text, signature } => {
410 MessageSegment::Thinking { text, signature }
411 }
412 SerializedMessageSegment::RedactedThinking { data } => {
413 MessageSegment::RedactedThinking(data)
414 }
415 })
416 .collect(),
417 context: message.context,
418 })
419 .collect(),
420 next_message_id,
421 last_prompt_id: PromptId::new(),
422 context: BTreeMap::default(),
423 context_by_message: HashMap::default(),
424 project_context,
425 checkpoints_by_message: HashMap::default(),
426 completion_count: 0,
427 pending_completions: Vec::new(),
428 last_restore_checkpoint: None,
429 pending_checkpoint: None,
430 project: project.clone(),
431 prompt_builder,
432 tools,
433 tool_use,
434 action_log: cx.new(|_| ActionLog::new(project)),
435 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
436 request_token_usage: serialized.request_token_usage,
437 cumulative_token_usage: serialized.cumulative_token_usage,
438 exceeded_window_error: None,
439 feedback: None,
440 message_feedback: HashMap::default(),
441 last_auto_capture_at: None,
442 request_callback: None,
443 }
444 }
445
446 pub fn set_request_callback(
447 &mut self,
448 callback: impl 'static
449 + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
450 ) {
451 self.request_callback = Some(Box::new(callback));
452 }
453
454 pub fn id(&self) -> &ThreadId {
455 &self.id
456 }
457
458 pub fn is_empty(&self) -> bool {
459 self.messages.is_empty()
460 }
461
462 pub fn updated_at(&self) -> DateTime<Utc> {
463 self.updated_at
464 }
465
466 pub fn touch_updated_at(&mut self) {
467 self.updated_at = Utc::now();
468 }
469
470 pub fn advance_prompt_id(&mut self) {
471 self.last_prompt_id = PromptId::new();
472 }
473
474 pub fn summary(&self) -> Option<SharedString> {
475 self.summary.clone()
476 }
477
478 pub fn project_context(&self) -> SharedProjectContext {
479 self.project_context.clone()
480 }
481
482 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
483
484 pub fn summary_or_default(&self) -> SharedString {
485 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
486 }
487
488 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
489 let Some(current_summary) = &self.summary else {
490 // Don't allow setting summary until generated
491 return;
492 };
493
494 let mut new_summary = new_summary.into();
495
496 if new_summary.is_empty() {
497 new_summary = Self::DEFAULT_SUMMARY;
498 }
499
500 if current_summary != &new_summary {
501 self.summary = Some(new_summary);
502 cx.emit(ThreadEvent::SummaryChanged);
503 }
504 }
505
506 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
507 self.latest_detailed_summary()
508 .unwrap_or_else(|| self.text().into())
509 }
510
511 fn latest_detailed_summary(&self) -> Option<SharedString> {
512 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
513 Some(text.clone())
514 } else {
515 None
516 }
517 }
518
519 pub fn message(&self, id: MessageId) -> Option<&Message> {
520 self.messages.iter().find(|message| message.id == id)
521 }
522
523 pub fn messages(&self) -> impl Iterator<Item = &Message> {
524 self.messages.iter()
525 }
526
527 pub fn is_generating(&self) -> bool {
528 !self.pending_completions.is_empty() || !self.all_tools_finished()
529 }
530
531 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
532 &self.tools
533 }
534
535 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
536 self.tool_use
537 .pending_tool_uses()
538 .into_iter()
539 .find(|tool_use| &tool_use.id == id)
540 }
541
542 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
543 self.tool_use
544 .pending_tool_uses()
545 .into_iter()
546 .filter(|tool_use| tool_use.status.needs_confirmation())
547 }
548
549 pub fn has_pending_tool_uses(&self) -> bool {
550 !self.tool_use.pending_tool_uses().is_empty()
551 }
552
553 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
554 self.checkpoints_by_message.get(&id).cloned()
555 }
556
557 pub fn restore_checkpoint(
558 &mut self,
559 checkpoint: ThreadCheckpoint,
560 cx: &mut Context<Self>,
561 ) -> Task<Result<()>> {
562 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
563 message_id: checkpoint.message_id,
564 });
565 cx.emit(ThreadEvent::CheckpointChanged);
566 cx.notify();
567
568 let git_store = self.project().read(cx).git_store().clone();
569 let restore = git_store.update(cx, |git_store, cx| {
570 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
571 });
572
573 cx.spawn(async move |this, cx| {
574 let result = restore.await;
575 this.update(cx, |this, cx| {
576 if let Err(err) = result.as_ref() {
577 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
578 message_id: checkpoint.message_id,
579 error: err.to_string(),
580 });
581 } else {
582 this.truncate(checkpoint.message_id, cx);
583 this.last_restore_checkpoint = None;
584 }
585 this.pending_checkpoint = None;
586 cx.emit(ThreadEvent::CheckpointChanged);
587 cx.notify();
588 })?;
589 result
590 })
591 }
592
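    /// Once generation has finished, compares the checkpoint taken when the user sent their
    /// message against a fresh checkpoint. If the project is unchanged the pending checkpoint
    /// is discarded; otherwise it is kept so the user can restore to it later.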
593 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
594 let pending_checkpoint = if self.is_generating() {
595 return;
596 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
597 checkpoint
598 } else {
599 return;
600 };
601
602 let git_store = self.project.read(cx).git_store().clone();
603 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
604 cx.spawn(async move |this, cx| match final_checkpoint.await {
605 Ok(final_checkpoint) => {
606 let equal = git_store
607 .update(cx, |store, cx| {
608 store.compare_checkpoints(
609 pending_checkpoint.git_checkpoint.clone(),
610 final_checkpoint.clone(),
611 cx,
612 )
613 })?
614 .await
615 .unwrap_or(false);
616
617 if equal {
618 git_store
619 .update(cx, |store, cx| {
620 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
621 })?
622 .detach();
623 } else {
624 this.update(cx, |this, cx| {
625 this.insert_checkpoint(pending_checkpoint, cx)
626 })?;
627 }
628
629 git_store
630 .update(cx, |store, cx| {
631 store.delete_checkpoint(final_checkpoint, cx)
632 })?
633 .detach();
634
635 Ok(())
636 }
637 Err(_) => this.update(cx, |this, cx| {
638 this.insert_checkpoint(pending_checkpoint, cx)
639 }),
640 })
641 .detach();
642 }
643
644 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
645 self.checkpoints_by_message
646 .insert(checkpoint.message_id, checkpoint);
647 cx.emit(ThreadEvent::CheckpointChanged);
648 cx.notify();
649 }
650
651 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
652 self.last_restore_checkpoint.as_ref()
653 }
654
655 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
656 let Some(message_ix) = self
657 .messages
658 .iter()
659 .rposition(|message| message.id == message_id)
660 else {
661 return;
662 };
663 for deleted_message in self.messages.drain(message_ix..) {
664 self.context_by_message.remove(&deleted_message.id);
665 self.checkpoints_by_message.remove(&deleted_message.id);
666 }
667 cx.notify();
668 }
669
670 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
671 self.context_by_message
672 .get(&id)
673 .into_iter()
674 .flat_map(|context| {
675 context
676 .iter()
677 .filter_map(|context_id| self.context.get(&context_id))
678 })
679 }
680
681 /// Returns whether all of the tool uses have finished running.
682 pub fn all_tools_finished(&self) -> bool {
683 // If the only pending tool uses left are the ones with errors, then
684 // that means that we've finished running all of the pending tools.
685 self.tool_use
686 .pending_tool_uses()
687 .iter()
688 .all(|tool_use| tool_use.status.is_error())
689 }
690
691 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
692 self.tool_use.tool_uses_for_message(id, cx)
693 }
694
695 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
696 self.tool_use.tool_results_for_message(id)
697 }
698
699 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
700 self.tool_use.tool_result(id)
701 }
702
703 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
704 Some(&self.tool_use.tool_result(id)?.content)
705 }
706
707 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
708 self.tool_use.tool_result_card(id).cloned()
709 }
710
711 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
712 self.tool_use.message_has_tool_results(message_id)
713 }
714
715 /// Filter out contexts that have already been included in previous messages
716 pub fn filter_new_context<'a>(
717 &self,
718 context: impl Iterator<Item = &'a AssistantContext>,
719 ) -> impl Iterator<Item = &'a AssistantContext> {
720 context.filter(|ctx| self.is_context_new(ctx))
721 }
722
723 fn is_context_new(&self, context: &AssistantContext) -> bool {
724 !self.context.contains_key(&context.id())
725 }
726
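    /// Inserts a user message together with any newly attached context, registering the
    /// context's buffers with the action log and recording a pending git checkpoint if one
    /// was captured.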
727 pub fn insert_user_message(
728 &mut self,
729 text: impl Into<String>,
730 context: Vec<AssistantContext>,
731 git_checkpoint: Option<GitStoreCheckpoint>,
732 cx: &mut Context<Self>,
733 ) -> MessageId {
734 let text = text.into();
735
736 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
737
738 let new_context: Vec<_> = context
739 .into_iter()
740 .filter(|ctx| self.is_context_new(ctx))
741 .collect();
742
743 if !new_context.is_empty() {
744 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
745 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
746 message.context = context_string;
747 }
748 }
749
750 self.action_log.update(cx, |log, cx| {
751 // Track all buffers added as context
752 for ctx in &new_context {
753 match ctx {
754 AssistantContext::File(file_ctx) => {
755 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
756 }
757 AssistantContext::Directory(dir_ctx) => {
758 for context_buffer in &dir_ctx.context_buffers {
759 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
760 }
761 }
762 AssistantContext::Symbol(symbol_ctx) => {
763 log.buffer_added_as_context(
764 symbol_ctx.context_symbol.buffer.clone(),
765 cx,
766 );
767 }
768 AssistantContext::Excerpt(excerpt_context) => {
769 log.buffer_added_as_context(
770 excerpt_context.context_buffer.buffer.clone(),
771 cx,
772 );
773 }
774 AssistantContext::FetchedUrl(_)
775 | AssistantContext::Thread(_)
776 | AssistantContext::Rules(_) => {}
777 }
778 }
779 });
780 }
781
782 let context_ids = new_context
783 .iter()
784 .map(|context| context.id())
785 .collect::<Vec<_>>();
786 self.context.extend(
787 new_context
788 .into_iter()
789 .map(|context| (context.id(), context)),
790 );
791 self.context_by_message.insert(message_id, context_ids);
792
793 if let Some(git_checkpoint) = git_checkpoint {
794 self.pending_checkpoint = Some(ThreadCheckpoint {
795 message_id,
796 git_checkpoint,
797 });
798 }
799
800 self.auto_capture_telemetry(cx);
801
802 message_id
803 }
804
805 pub fn insert_message(
806 &mut self,
807 role: Role,
808 segments: Vec<MessageSegment>,
809 cx: &mut Context<Self>,
810 ) -> MessageId {
811 let id = self.next_message_id.post_inc();
812 self.messages.push(Message {
813 id,
814 role,
815 segments,
816 context: String::new(),
817 });
818 self.touch_updated_at();
819 cx.emit(ThreadEvent::MessageAdded(id));
820 id
821 }
822
823 pub fn edit_message(
824 &mut self,
825 id: MessageId,
826 new_role: Role,
827 new_segments: Vec<MessageSegment>,
828 cx: &mut Context<Self>,
829 ) -> bool {
830 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
831 return false;
832 };
833 message.role = new_role;
834 message.segments = new_segments;
835 self.touch_updated_at();
836 cx.emit(ThreadEvent::MessageEdited(id));
837 true
838 }
839
840 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
841 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
842 return false;
843 };
844 self.messages.remove(index);
845 self.context_by_message.remove(&id);
846 self.touch_updated_at();
847 cx.emit(ThreadEvent::MessageDeleted(id));
848 true
849 }
850
851 /// Returns the representation of this [`Thread`] in a textual form.
852 ///
853 /// This is the representation we use when attaching a thread as context to another thread.
854 pub fn text(&self) -> String {
855 let mut text = String::new();
856
857 for message in &self.messages {
858 text.push_str(match message.role {
859 language_model::Role::User => "User:",
860 language_model::Role::Assistant => "Assistant:",
861 language_model::Role::System => "System:",
862 });
863 text.push('\n');
864
865 for segment in &message.segments {
866 match segment {
867 MessageSegment::Text(content) => text.push_str(content),
868 MessageSegment::Thinking { text: content, .. } => {
869 text.push_str(&format!("<think>{}</think>", content))
870 }
871 MessageSegment::RedactedThinking(_) => {}
872 }
873 }
874 text.push('\n');
875 }
876
877 text
878 }
879
880 /// Serializes this thread into a format for storage or telemetry.
881 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
882 let initial_project_snapshot = self.initial_project_snapshot.clone();
883 cx.spawn(async move |this, cx| {
884 let initial_project_snapshot = initial_project_snapshot.await;
885 this.read_with(cx, |this, cx| SerializedThread {
886 version: SerializedThread::VERSION.to_string(),
887 summary: this.summary_or_default(),
888 updated_at: this.updated_at(),
889 messages: this
890 .messages()
891 .map(|message| SerializedMessage {
892 id: message.id,
893 role: message.role,
894 segments: message
895 .segments
896 .iter()
897 .map(|segment| match segment {
898 MessageSegment::Text(text) => {
899 SerializedMessageSegment::Text { text: text.clone() }
900 }
901 MessageSegment::Thinking { text, signature } => {
902 SerializedMessageSegment::Thinking {
903 text: text.clone(),
904 signature: signature.clone(),
905 }
906 }
907 MessageSegment::RedactedThinking(data) => {
908 SerializedMessageSegment::RedactedThinking {
909 data: data.clone(),
910 }
911 }
912 })
913 .collect(),
914 tool_uses: this
915 .tool_uses_for_message(message.id, cx)
916 .into_iter()
917 .map(|tool_use| SerializedToolUse {
918 id: tool_use.id,
919 name: tool_use.name,
920 input: tool_use.input,
921 })
922 .collect(),
923 tool_results: this
924 .tool_results_for_message(message.id)
925 .into_iter()
926 .map(|tool_result| SerializedToolResult {
927 tool_use_id: tool_result.tool_use_id.clone(),
928 is_error: tool_result.is_error,
929 content: tool_result.content.clone(),
930 })
931 .collect(),
932 context: message.context.clone(),
933 })
934 .collect(),
935 initial_project_snapshot,
936 cumulative_token_usage: this.cumulative_token_usage,
937 request_token_usage: this.request_token_usage.clone(),
938 detailed_summary_state: this.detailed_summary_state.clone(),
939 exceeded_window_error: this.exceeded_window_error.clone(),
940 })
941 })
942 }
943
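    /// Builds a completion request from this thread and streams it to the given model,
    /// advertising the enabled tools when the model supports tool use.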
944 pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
945 let mut request = self.to_completion_request(cx);
946 if model.supports_tools() {
947 request.tools = {
948 let mut tools = Vec::new();
949 tools.extend(
950 self.tools()
951 .read(cx)
952 .enabled_tools(cx)
953 .into_iter()
954 .filter_map(|tool| {
955 // Skip tools that cannot be supported
956 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
957 Some(LanguageModelRequestTool {
958 name: tool.name(),
959 description: tool.description(),
960 input_schema,
961 })
962 }),
963 );
964
965 tools
966 };
967 }
968
969 self.stream_completion(request, model, cx);
970 }
971
972 pub fn used_tools_since_last_user_message(&self) -> bool {
973 for message in self.messages.iter().rev() {
974 if self.tool_use.message_has_tool_results(message.id) {
975 return true;
976 } else if message.role == Role::User {
977 return false;
978 }
979 }
980
981 false
982 }
983
984 pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
985 let mut request = LanguageModelRequest {
986 thread_id: Some(self.id.to_string()),
987 prompt_id: Some(self.last_prompt_id.to_string()),
988 messages: vec![],
989 tools: Vec::new(),
990 stop: Vec::new(),
991 temperature: None,
992 };
993
994 if let Some(project_context) = self.project_context.borrow().as_ref() {
995 match self
996 .prompt_builder
997 .generate_assistant_system_prompt(project_context)
998 {
999 Err(err) => {
1000 let message = format!("{err:?}").into();
1001 log::error!("{message}");
1002 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1003 header: "Error generating system prompt".into(),
1004 message,
1005 }));
1006 }
1007 Ok(system_prompt) => {
1008 request.messages.push(LanguageModelRequestMessage {
1009 role: Role::System,
1010 content: vec![MessageContent::Text(system_prompt)],
1011 cache: true,
1012 });
1013 }
1014 }
1015 } else {
1016 let message = "Context for system prompt unexpectedly not ready.".into();
1017 log::error!("{message}");
1018 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1019 header: "Error generating system prompt".into(),
1020 message,
1021 }));
1022 }
1023
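        // Convert each thread message into a request message, attaching its context string,
        // text and thinking segments, and any recorded tool uses and tool results.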
1024 for message in &self.messages {
1025 let mut request_message = LanguageModelRequestMessage {
1026 role: message.role,
1027 content: Vec::new(),
1028 cache: false,
1029 };
1030
1031 self.tool_use
1032 .attach_tool_results(message.id, &mut request_message);
1033
1034 if !message.context.is_empty() {
1035 request_message
1036 .content
1037 .push(MessageContent::Text(message.context.to_string()));
1038 }
1039
1040 for segment in &message.segments {
1041 match segment {
1042 MessageSegment::Text(text) => {
1043 if !text.is_empty() {
1044 request_message
1045 .content
1046 .push(MessageContent::Text(text.into()));
1047 }
1048 }
1049 MessageSegment::Thinking { text, signature } => {
1050 if !text.is_empty() {
1051 request_message.content.push(MessageContent::Thinking {
1052 text: text.into(),
1053 signature: signature.clone(),
1054 });
1055 }
1056 }
1057 MessageSegment::RedactedThinking(data) => {
1058 request_message
1059 .content
1060 .push(MessageContent::RedactedThinking(data.clone()));
1061 }
1062 };
1063 }
1064
1065 self.tool_use
1066 .attach_tool_uses(message.id, &mut request_message);
1067
1068 request.messages.push(request_message);
1069 }
1070
1071 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1072 if let Some(last) = request.messages.last_mut() {
1073 last.cache = true;
1074 }
1075
1076 self.attached_tracked_files_state(&mut request.messages, cx);
1077
1078 request
1079 }
1080
1081 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1082 let mut request = LanguageModelRequest {
1083 thread_id: None,
1084 prompt_id: None,
1085 messages: vec![],
1086 tools: Vec::new(),
1087 stop: Vec::new(),
1088 temperature: None,
1089 };
1090
1091 for message in &self.messages {
1092 let mut request_message = LanguageModelRequestMessage {
1093 role: message.role,
1094 content: Vec::new(),
1095 cache: false,
1096 };
1097
1098 // Skip tool results during summarization.
1099 if self.tool_use.message_has_tool_results(message.id) {
1100 continue;
1101 }
1102
1103 for segment in &message.segments {
1104 match segment {
1105 MessageSegment::Text(text) => request_message
1106 .content
1107 .push(MessageContent::Text(text.clone())),
1108 MessageSegment::Thinking { .. } => {}
1109 MessageSegment::RedactedThinking(_) => {}
1110 }
1111 }
1112
1113 if request_message.content.is_empty() {
1114 continue;
1115 }
1116
1117 request.messages.push(request_message);
1118 }
1119
1120 request.messages.push(LanguageModelRequestMessage {
1121 role: Role::User,
1122 content: vec![MessageContent::Text(added_user_message)],
1123 cache: false,
1124 });
1125
1126 request
1127 }
1128
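    /// Appends an extra user message listing buffers that have changed since the model last
    /// read them, so the model knows its view of those files may be stale.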
1129 fn attached_tracked_files_state(
1130 &self,
1131 messages: &mut Vec<LanguageModelRequestMessage>,
1132 cx: &App,
1133 ) {
1134 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1135
1136 let mut stale_message = String::new();
1137
1138 let action_log = self.action_log.read(cx);
1139
1140 for stale_file in action_log.stale_buffers(cx) {
1141 let Some(file) = stale_file.read(cx).file() else {
1142 continue;
1143 };
1144
            if stale_message.is_empty() {
                writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
            }
1148
1149 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1150 }
1151
1152 let mut content = Vec::with_capacity(2);
1153
1154 if !stale_message.is_empty() {
1155 content.push(stale_message.into());
1156 }
1157
1158 if !content.is_empty() {
1159 let context_message = LanguageModelRequestMessage {
1160 role: Role::User,
1161 content,
1162 cache: false,
1163 };
1164
1165 messages.push(context_message);
1166 }
1167 }
1168
1169 pub fn stream_completion(
1170 &mut self,
1171 request: LanguageModelRequest,
1172 model: Arc<dyn LanguageModel>,
1173 cx: &mut Context<Self>,
1174 ) {
1175 let pending_completion_id = post_inc(&mut self.completion_count);
1176 let mut request_callback_parameters = if self.request_callback.is_some() {
1177 Some((request.clone(), Vec::new()))
1178 } else {
1179 None
1180 };
1181 let prompt_id = self.last_prompt_id.clone();
1182 let tool_use_metadata = ToolUseMetadata {
1183 model: model.clone(),
1184 thread_id: self.id.clone(),
1185 prompt_id: prompt_id.clone(),
1186 };
1187
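        // Stream the completion on a spawned task, applying each event to the thread as it
        // arrives and reporting errors or the stop reason once the stream ends.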
1188 let task = cx.spawn(async move |thread, cx| {
1189 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1190 let initial_token_usage =
1191 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1192 let stream_completion = async {
1193 let (mut events, usage) = stream_completion_future.await?;
1194
1195 let mut stop_reason = StopReason::EndTurn;
1196 let mut current_token_usage = TokenUsage::default();
1197
1198 if let Some(usage) = usage {
1199 thread
1200 .update(cx, |_thread, cx| {
1201 cx.emit(ThreadEvent::UsageUpdated(usage));
1202 })
1203 .ok();
1204 }
1205
1206 while let Some(event) = events.next().await {
1207 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1208 response_events
1209 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1210 }
1211
1212 let event = event?;
1213
1214 thread.update(cx, |thread, cx| {
1215 match event {
1216 LanguageModelCompletionEvent::StartMessage { .. } => {
1217 thread.insert_message(
1218 Role::Assistant,
1219 vec![MessageSegment::Text(String::new())],
1220 cx,
1221 );
1222 }
1223 LanguageModelCompletionEvent::Stop(reason) => {
1224 stop_reason = reason;
1225 }
1226 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1227 thread.update_token_usage_at_last_message(token_usage);
1228 thread.cumulative_token_usage = thread.cumulative_token_usage
1229 + token_usage
1230 - current_token_usage;
1231 current_token_usage = token_usage;
1232 }
1233 LanguageModelCompletionEvent::Text(chunk) => {
1234 if let Some(last_message) = thread.messages.last_mut() {
1235 if last_message.role == Role::Assistant {
1236 last_message.push_text(&chunk);
1237 cx.emit(ThreadEvent::StreamedAssistantText(
1238 last_message.id,
1239 chunk,
1240 ));
1241 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // would result in duplicating the text of the chunk in the rendered Markdown.
1247 thread.insert_message(
1248 Role::Assistant,
1249 vec![MessageSegment::Text(chunk.to_string())],
1250 cx,
1251 );
1252 };
1253 }
1254 }
1255 LanguageModelCompletionEvent::Thinking {
1256 text: chunk,
1257 signature,
1258 } => {
1259 if let Some(last_message) = thread.messages.last_mut() {
1260 if last_message.role == Role::Assistant {
1261 last_message.push_thinking(&chunk, signature);
1262 cx.emit(ThreadEvent::StreamedAssistantThinking(
1263 last_message.id,
1264 chunk,
1265 ));
1266 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // would result in duplicating the text of the chunk in the rendered Markdown.
1272 thread.insert_message(
1273 Role::Assistant,
1274 vec![MessageSegment::Thinking {
1275 text: chunk.to_string(),
1276 signature,
1277 }],
1278 cx,
1279 );
1280 };
1281 }
1282 }
1283 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1284 let last_assistant_message_id = thread
1285 .messages
1286 .iter_mut()
1287 .rfind(|message| message.role == Role::Assistant)
1288 .map(|message| message.id)
1289 .unwrap_or_else(|| {
1290 thread.insert_message(Role::Assistant, vec![], cx)
1291 });
1292
1293 let tool_use_id = tool_use.id.clone();
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some(tool_use.input.clone())
                                };
1299
1300 let ui_text = thread.tool_use.request_tool_use(
1301 last_assistant_message_id,
1302 tool_use,
1303 tool_use_metadata.clone(),
1304 cx,
1305 );
1306
1307 if let Some(input) = streamed_input {
1308 cx.emit(ThreadEvent::StreamedToolUse {
1309 tool_use_id,
1310 ui_text,
1311 input,
1312 });
1313 }
1314 }
1315 }
1316
1317 thread.touch_updated_at();
1318 cx.emit(ThreadEvent::StreamedCompletion);
1319 cx.notify();
1320
1321 thread.auto_capture_telemetry(cx);
1322 })?;
1323
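                    // Yield between events so other tasks on the executor can make progress
                    // while the stream is being consumed.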
1324 smol::future::yield_now().await;
1325 }
1326
1327 thread.update(cx, |thread, cx| {
1328 thread
1329 .pending_completions
1330 .retain(|completion| completion.id != pending_completion_id);
1331
1332 // If there is a response without tool use, summarize the message. Otherwise,
1333 // allow two tool uses before summarizing.
1334 if thread.summary.is_none()
1335 && thread.messages.len() >= 2
1336 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1337 {
1338 thread.summarize(cx);
1339 }
1340 })?;
1341
1342 anyhow::Ok(stop_reason)
1343 };
1344
1345 let result = stream_completion.await;
1346
1347 thread
1348 .update(cx, |thread, cx| {
1349 thread.finalize_pending_checkpoint(cx);
1350 match result.as_ref() {
1351 Ok(stop_reason) => match stop_reason {
1352 StopReason::ToolUse => {
1353 let tool_uses = thread.use_pending_tools(cx);
1354 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1355 }
1356 StopReason::EndTurn => {}
1357 StopReason::MaxTokens => {}
1358 },
1359 Err(error) => {
1360 if error.is::<PaymentRequiredError>() {
1361 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1362 } else if error.is::<MaxMonthlySpendReachedError>() {
1363 cx.emit(ThreadEvent::ShowError(
1364 ThreadError::MaxMonthlySpendReached,
1365 ));
1366 } else if let Some(error) =
1367 error.downcast_ref::<ModelRequestLimitReachedError>()
1368 {
1369 cx.emit(ThreadEvent::ShowError(
1370 ThreadError::ModelRequestLimitReached { plan: error.plan },
1371 ));
1372 } else if let Some(known_error) =
1373 error.downcast_ref::<LanguageModelKnownError>()
1374 {
1375 match known_error {
1376 LanguageModelKnownError::ContextWindowLimitExceeded {
1377 tokens,
1378 } => {
1379 thread.exceeded_window_error = Some(ExceededWindowError {
1380 model_id: model.id(),
1381 token_count: *tokens,
1382 });
1383 cx.notify();
1384 }
1385 }
1386 } else {
1387 let error_message = error
1388 .chain()
1389 .map(|err| err.to_string())
1390 .collect::<Vec<_>>()
1391 .join("\n");
1392 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1393 header: "Error interacting with language model".into(),
1394 message: SharedString::from(error_message.clone()),
1395 }));
1396 }
1397
1398 thread.cancel_last_completion(cx);
1399 }
1400 }
1401 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1402
1403 if let Some((request_callback, (request, response_events))) = thread
1404 .request_callback
1405 .as_mut()
1406 .zip(request_callback_parameters.as_ref())
1407 {
1408 request_callback(request, response_events);
1409 }
1410
1411 thread.auto_capture_telemetry(cx);
1412
1413 if let Ok(initial_usage) = initial_token_usage {
1414 let usage = thread.cumulative_token_usage - initial_usage;
1415
1416 telemetry::event!(
1417 "Assistant Thread Completion",
1418 thread_id = thread.id().to_string(),
1419 prompt_id = prompt_id,
1420 model = model.telemetry_id(),
1421 model_provider = model.provider_id().to_string(),
1422 input_tokens = usage.input_tokens,
1423 output_tokens = usage.output_tokens,
1424 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1425 cache_read_input_tokens = usage.cache_read_input_tokens,
1426 );
1427 }
1428 })
1429 .ok();
1430 });
1431
1432 self.pending_completions.push(PendingCompletion {
1433 id: pending_completion_id,
1434 _task: task,
1435 });
1436 }
1437
1438 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1439 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1440 return;
1441 };
1442
1443 if !model.provider.is_authenticated(cx) {
1444 return;
1445 }
1446
1447 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1448 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1449 If the conversation is about a specific subject, include it in the title. \
1450 Be descriptive. DO NOT speak in the first person.";
1451
1452 let request = self.to_summarize_request(added_user_message.into());
1453
1454 self.pending_summary = cx.spawn(async move |this, cx| {
1455 async move {
1456 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1457 let (mut messages, usage) = stream.await?;
1458
1459 if let Some(usage) = usage {
1460 this.update(cx, |_thread, cx| {
1461 cx.emit(ThreadEvent::UsageUpdated(usage));
1462 })
1463 .ok();
1464 }
1465
1466 let mut new_summary = String::new();
1467 while let Some(message) = messages.stream.next().await {
1468 let text = message?;
1469 let mut lines = text.lines();
1470 new_summary.extend(lines.next());
1471
1472 // Stop if the LLM generated multiple lines.
1473 if lines.next().is_some() {
1474 break;
1475 }
1476 }
1477
1478 this.update(cx, |this, cx| {
1479 if !new_summary.is_empty() {
1480 this.summary = Some(new_summary.into());
1481 }
1482
1483 cx.emit(ThreadEvent::SummaryGenerated);
1484 })?;
1485
1486 anyhow::Ok(())
1487 }
1488 .log_err()
1489 .await
1490 });
1491 }
1492
1493 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1494 let last_message_id = self.messages.last().map(|message| message.id)?;
1495
1496 match &self.detailed_summary_state {
1497 DetailedSummaryState::Generating { message_id, .. }
1498 | DetailedSummaryState::Generated { message_id, .. }
1499 if *message_id == last_message_id =>
1500 {
1501 // Already up-to-date
1502 return None;
1503 }
1504 _ => {}
1505 }
1506
1507 let ConfiguredModel { model, provider } =
1508 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1509
1510 if !provider.is_authenticated(cx) {
1511 return None;
1512 }
1513
1514 let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1515 1. A brief overview of what was discussed\n\
1516 2. Key facts or information discovered\n\
1517 3. Outcomes or conclusions reached\n\
1518 4. Any action items or next steps if any\n\
1519 Format it in Markdown with headings and bullet points.";
1520
1521 let request = self.to_summarize_request(added_user_message.into());
1522
1523 let task = cx.spawn(async move |thread, cx| {
1524 let stream = model.stream_completion_text(request, &cx);
1525 let Some(mut messages) = stream.await.log_err() else {
1526 thread
1527 .update(cx, |this, _cx| {
1528 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1529 })
1530 .log_err();
1531
1532 return;
1533 };
1534
1535 let mut new_detailed_summary = String::new();
1536
1537 while let Some(chunk) = messages.stream.next().await {
1538 if let Some(chunk) = chunk.log_err() {
1539 new_detailed_summary.push_str(&chunk);
1540 }
1541 }
1542
1543 thread
1544 .update(cx, |this, _cx| {
1545 this.detailed_summary_state = DetailedSummaryState::Generated {
1546 text: new_detailed_summary.into(),
1547 message_id: last_message_id,
1548 };
1549 })
1550 .log_err();
1551 });
1552
1553 self.detailed_summary_state = DetailedSummaryState::Generating {
1554 message_id: last_message_id,
1555 };
1556
1557 Some(task)
1558 }
1559
1560 pub fn is_generating_detailed_summary(&self) -> bool {
1561 matches!(
1562 self.detailed_summary_state,
1563 DetailedSummaryState::Generating { .. }
1564 )
1565 }
1566
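    /// Runs every idle pending tool use, first asking for confirmation when the tool requires
    /// it and `always_allow_tool_actions` is not enabled.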
1567 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1568 self.auto_capture_telemetry(cx);
1569 let request = self.to_completion_request(cx);
1570 let messages = Arc::new(request.messages);
1571 let pending_tool_uses = self
1572 .tool_use
1573 .pending_tool_uses()
1574 .into_iter()
1575 .filter(|tool_use| tool_use.status.is_idle())
1576 .cloned()
1577 .collect::<Vec<_>>();
1578
1579 for tool_use in pending_tool_uses.iter() {
1580 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1581 if tool.needs_confirmation(&tool_use.input, cx)
1582 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1583 {
1584 self.tool_use.confirm_tool_use(
1585 tool_use.id.clone(),
1586 tool_use.ui_text.clone(),
1587 tool_use.input.clone(),
1588 messages.clone(),
1589 tool,
1590 );
1591 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1592 } else {
1593 self.run_tool(
1594 tool_use.id.clone(),
1595 tool_use.ui_text.clone(),
1596 tool_use.input.clone(),
1597 &messages,
1598 tool,
1599 cx,
1600 );
1601 }
1602 }
1603 }
1604
1605 pending_tool_uses
1606 }
1607
1608 pub fn run_tool(
1609 &mut self,
1610 tool_use_id: LanguageModelToolUseId,
1611 ui_text: impl Into<SharedString>,
1612 input: serde_json::Value,
1613 messages: &[LanguageModelRequestMessage],
1614 tool: Arc<dyn Tool>,
1615 cx: &mut Context<Thread>,
1616 ) {
1617 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1618 self.tool_use
1619 .run_pending_tool(tool_use_id, ui_text.into(), task);
1620 }
1621
1622 fn spawn_tool_use(
1623 &mut self,
1624 tool_use_id: LanguageModelToolUseId,
1625 messages: &[LanguageModelRequestMessage],
1626 input: serde_json::Value,
1627 tool: Arc<dyn Tool>,
1628 cx: &mut Context<Thread>,
1629 ) -> Task<()> {
1630 let tool_name: Arc<str> = tool.name().into();
1631
1632 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1633 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1634 } else {
1635 tool.run(
1636 input,
1637 messages,
1638 self.project.clone(),
1639 self.action_log.clone(),
1640 cx,
1641 )
1642 };
1643
1644 // Store the card separately if it exists
1645 if let Some(card) = tool_result.card.clone() {
1646 self.tool_use
1647 .insert_tool_result_card(tool_use_id.clone(), card);
1648 }
1649
1650 cx.spawn({
1651 async move |thread: WeakEntity<Thread>, cx| {
1652 let output = tool_result.output.await;
1653
1654 thread
1655 .update(cx, |thread, cx| {
1656 let pending_tool_use = thread.tool_use.insert_tool_output(
1657 tool_use_id.clone(),
1658 tool_name,
1659 output,
1660 cx,
1661 );
1662 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1663 })
1664 .ok();
1665 }
1666 })
1667 }
1668
1669 fn tool_finished(
1670 &mut self,
1671 tool_use_id: LanguageModelToolUseId,
1672 pending_tool_use: Option<PendingToolUse>,
1673 canceled: bool,
1674 cx: &mut Context<Self>,
1675 ) {
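        // Once every pending tool has finished, attach the accumulated tool results and
        // automatically send them back to the model, unless this run was canceled.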
1676 if self.all_tools_finished() {
1677 let model_registry = LanguageModelRegistry::read_global(cx);
1678 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1679 self.attach_tool_results(cx);
1680 if !canceled {
1681 self.send_to_model(model, cx);
1682 }
1683 }
1684 }
1685
1686 cx.emit(ThreadEvent::ToolFinished {
1687 tool_use_id,
1688 pending_tool_use,
1689 });
1690 }
1691
1692 /// Insert an empty message to be populated with tool results upon send.
1693 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Tool results are assumed to be waiting on the next message ID, so they will populate
        // this empty message before sending to the model. Would prefer this to be more straightforward.
1696 self.insert_message(Role::User, vec![], cx);
1697 self.auto_capture_telemetry(cx);
1698 }
1699
    /// Cancels the last pending completion, if any are pending.
    ///
    /// Returns whether a completion was canceled.
1703 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1704 let canceled = if self.pending_completions.pop().is_some() {
1705 true
1706 } else {
1707 let mut canceled = false;
1708 for pending_tool_use in self.tool_use.cancel_pending() {
1709 canceled = true;
1710 self.tool_finished(
1711 pending_tool_use.id.clone(),
1712 Some(pending_tool_use),
1713 true,
1714 cx,
1715 );
1716 }
1717 canceled
1718 };
1719 self.finalize_pending_checkpoint(cx);
1720 canceled
1721 }
1722
1723 pub fn feedback(&self) -> Option<ThreadFeedback> {
1724 self.feedback
1725 }
1726
1727 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1728 self.message_feedback.get(&message_id).copied()
1729 }
1730
1731 pub fn report_message_feedback(
1732 &mut self,
1733 message_id: MessageId,
1734 feedback: ThreadFeedback,
1735 cx: &mut Context<Self>,
1736 ) -> Task<Result<()>> {
1737 if self.message_feedback.get(&message_id) == Some(&feedback) {
1738 return Task::ready(Ok(()));
1739 }
1740
1741 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1742 let serialized_thread = self.serialize(cx);
1743 let thread_id = self.id().clone();
1744 let client = self.project.read(cx).client();
1745
1746 let enabled_tool_names: Vec<String> = self
1747 .tools()
1748 .read(cx)
1749 .enabled_tools(cx)
1750 .iter()
1751 .map(|tool| tool.name().to_string())
1752 .collect();
1753
1754 self.message_feedback.insert(message_id, feedback);
1755
1756 cx.notify();
1757
1758 let message_content = self
1759 .message(message_id)
1760 .map(|msg| msg.to_string())
1761 .unwrap_or_default();
1762
1763 cx.background_spawn(async move {
1764 let final_project_snapshot = final_project_snapshot.await;
1765 let serialized_thread = serialized_thread.await?;
1766 let thread_data =
1767 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1768
1769 let rating = match feedback {
1770 ThreadFeedback::Positive => "positive",
1771 ThreadFeedback::Negative => "negative",
1772 };
1773 telemetry::event!(
1774 "Assistant Thread Rated",
1775 rating,
1776 thread_id,
1777 enabled_tool_names,
1778 message_id = message_id.0,
1779 message_content,
1780 thread_data,
1781 final_project_snapshot
1782 );
1783 client.telemetry().flush_events();
1784
1785 Ok(())
1786 })
1787 }
1788
1789 pub fn report_feedback(
1790 &mut self,
1791 feedback: ThreadFeedback,
1792 cx: &mut Context<Self>,
1793 ) -> Task<Result<()>> {
1794 let last_assistant_message_id = self
1795 .messages
1796 .iter()
1797 .rev()
1798 .find(|msg| msg.role == Role::Assistant)
1799 .map(|msg| msg.id);
1800
1801 if let Some(message_id) = last_assistant_message_id {
1802 self.report_message_feedback(message_id, feedback, cx)
1803 } else {
1804 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1805 let serialized_thread = self.serialize(cx);
1806 let thread_id = self.id().clone();
1807 let client = self.project.read(cx).client();
1808 self.feedback = Some(feedback);
1809 cx.notify();
1810
1811 cx.background_spawn(async move {
1812 let final_project_snapshot = final_project_snapshot.await;
1813 let serialized_thread = serialized_thread.await?;
1814 let thread_data = serde_json::to_value(serialized_thread)
1815 .unwrap_or_else(|_| serde_json::Value::Null);
1816
1817 let rating = match feedback {
1818 ThreadFeedback::Positive => "positive",
1819 ThreadFeedback::Negative => "negative",
1820 };
1821 telemetry::event!(
1822 "Assistant Thread Rated",
1823 rating,
1824 thread_id,
1825 thread_data,
1826 final_project_snapshot
1827 );
1828 client.telemetry().flush_events();
1829
1830 Ok(())
1831 })
1832 }
1833 }
1834
1835 /// Create a snapshot of the current project state including git information and unsaved buffers.
1836 fn project_snapshot(
1837 project: Entity<Project>,
1838 cx: &mut Context<Self>,
1839 ) -> Task<Arc<ProjectSnapshot>> {
1840 let git_store = project.read(cx).git_store().clone();
1841 let worktree_snapshots: Vec<_> = project
1842 .read(cx)
1843 .visible_worktrees(cx)
1844 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1845 .collect();
1846
1847 cx.spawn(async move |_, cx| {
1848 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1849
1850 let mut unsaved_buffers = Vec::new();
1851 cx.update(|app_cx| {
1852 let buffer_store = project.read(app_cx).buffer_store();
1853 for buffer_handle in buffer_store.read(app_cx).buffers() {
1854 let buffer = buffer_handle.read(app_cx);
1855 if buffer.is_dirty() {
1856 if let Some(file) = buffer.file() {
1857 let path = file.path().to_string_lossy().to_string();
1858 unsaved_buffers.push(path);
1859 }
1860 }
1861 }
1862 })
1863 .ok();
1864
1865 Arc::new(ProjectSnapshot {
1866 worktree_snapshots,
1867 unsaved_buffer_paths: unsaved_buffers,
1868 timestamp: Utc::now(),
1869 })
1870 })
1871 }
1872
1873 fn worktree_snapshot(
1874 worktree: Entity<project::Worktree>,
1875 git_store: Entity<GitStore>,
1876 cx: &App,
1877 ) -> Task<WorktreeSnapshot> {
1878 cx.spawn(async move |cx| {
1879 // Get worktree path and snapshot
1880 let worktree_info = cx.update(|app_cx| {
1881 let worktree = worktree.read(app_cx);
1882 let path = worktree.abs_path().to_string_lossy().to_string();
1883 let snapshot = worktree.snapshot();
1884 (path, snapshot)
1885 });
1886
1887 let Ok((worktree_path, _snapshot)) = worktree_info else {
1888 return WorktreeSnapshot {
1889 worktree_path: String::new(),
1890 git_state: None,
1891 };
1892 };
1893
1894 let git_state = git_store
1895 .update(cx, |git_store, cx| {
1896 git_store
1897 .repositories()
1898 .values()
1899 .find(|repo| {
1900 repo.read(cx)
1901 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1902 .is_some()
1903 })
1904 .cloned()
1905 })
1906 .ok()
1907 .flatten()
1908 .map(|repo| {
1909 repo.update(cx, |repo, _| {
1910 let current_branch =
1911 repo.branch.as_ref().map(|branch| branch.name.to_string());
1912 repo.send_job(None, |state, _| async move {
1913 let RepositoryState::Local { backend, .. } = state else {
1914 return GitState {
1915 remote_url: None,
1916 head_sha: None,
1917 current_branch,
1918 diff: None,
1919 };
1920 };
1921
1922 let remote_url = backend.remote_url("origin");
1923 let head_sha = backend.head_sha();
1924 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1925
1926 GitState {
1927 remote_url,
1928 head_sha,
1929 current_branch,
1930 diff,
1931 }
1932 })
1933 })
1934 });
1935
            let git_state = match git_state.and_then(|state| state.ok()) {
                Some(git_state) => git_state.await.ok(),
                None => None,
            };
1943
1944 WorktreeSnapshot {
1945 worktree_path,
1946 git_state,
1947 }
1948 })
1949 }
1950
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        let mut markdown = Vec::new();

        if let Some(summary) = self.summary() {
            writeln!(markdown, "# {summary}\n")?;
        }

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Assistant",
                    Role::System => "System",
                }
            )?;

            if !message.context.is_empty() {
                writeln!(markdown, "{}", message.context)?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                writeln!(markdown, "{}", tool_result.content)?;
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }

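    /// Marks the agent's edits within `buffer_range` as kept (accepted), delegating
    /// to the thread's [`ActionLog`].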
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }

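    /// Marks every edit tracked by the [`ActionLog`] as kept.
    ///
    /// A minimal caller sketch (assumes an `Entity<Thread>` handle named `thread`
    /// and a `&mut App` context; neither is defined here):
    ///
    /// ```ignore
    /// thread.update(cx, |thread, cx| thread.keep_all_edits(cx));
    /// ```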
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }

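    /// Rejects the agent's edits within the given buffer ranges via the
    /// [`ActionLog`], returning the task that performs the rejection.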
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }

    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

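    /// Opportunistically captures the serialized thread as a telemetry event.
    /// Only runs when the `ThreadAutoCapture` feature flag is enabled, and is
    /// throttled to at most once every 10 seconds.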
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
            return;
        }

        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client().clone();
        let serialize_task = self.serialize(cx);

        cx.background_executor()
            .spawn(async move {
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events();
                    }
                }
            })
            .detach();
    }

    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }

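    /// Returns the token usage recorded for the request preceding `message_id`,
    /// along with the default model's maximum token count. Yields a zero total for
    /// the first message (or an unknown one) and a default value when no default
    /// model is configured.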
    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
            return TotalTokenUsage::default();
        };

        let max = model.model.max_token_count();

        let index = self
            .messages
            .iter()
            .position(|msg| msg.id == message_id)
            .unwrap_or(0);

        if index == 0 {
            return TotalTokenUsage { total: 0, max };
        }

        let token_usage = self
            .request_token_usage
            .get(index - 1)
            .cloned()
            .unwrap_or_default();

        TotalTokenUsage {
            total: token_usage.total_tokens() as usize,
            max,
        }
    }

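    /// Returns the thread's current token usage against the default model's context
    /// window. If that model previously exceeded its context window on this thread,
    /// the token count recorded at that point is reported instead.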
    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
        let model_registry = LanguageModelRegistry::read_global(cx);
        let Some(model) = model_registry.default_model() else {
            return TotalTokenUsage::default();
        };

        let max = model.model.max_token_count();

        if let Some(exceeded_error) = &self.exceeded_window_error {
            if model.model.id() == exceeded_error.model_id {
                return TotalTokenUsage {
                    total: exceeded_error.token_count,
                    max,
                };
            }
        }

        let total = self
            .token_usage_at_last_message()
            .unwrap_or_default()
            .total_tokens() as usize;

        TotalTokenUsage { total, max }
    }

    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }

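    /// Records `token_usage` as the usage for the most recent message, growing the
    /// per-message usage history to the current message count (backfilled with the
    /// last known usage) before overwriting the final entry.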
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }

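    /// Resolves a pending tool use with a "permission denied by user" error and
    /// marks the tool call as finished.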
    pub fn deny_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        cx: &mut Context<Self>,
    ) {
        let err = Err(anyhow::anyhow!(
            "Permission to run tool action denied by user"
        ));

        self.tool_use
            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
        self.tool_finished(tool_use_id.clone(), None, true, cx);
    }
}

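/// Errors surfaced while running a thread, reported through
/// [`ThreadEvent::ShowError`].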
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    #[error("Payment required")]
    PaymentRequired,
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}

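/// Events emitted by a [`Thread`] as completions stream, messages change, and
/// tools are used.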
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    ShowError(ThreadError),
    UsageUpdated(RequestUsage),
    StreamedCompletion,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}

impl EventEmitter<ThreadEvent> for Thread {}

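/// An in-flight completion request: its id plus the task driving the stream, held
/// so the task stays alive while the completion is pending.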
struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // The request should contain all 3 user messages after the initial message
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text near the start of the buffer; the exact position doesn't matter,
            // only that the contents change after the buffer was attached as context.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

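    // Helper to build the workspace, thread store, thread, and context store shared by these tests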
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

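    // Helper to open a buffer for `path` in the project and add it to the context store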
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}