1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
18};
19use language_model::{
20 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
21 LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
22 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
23 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
24 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
25 TokenUsage,
26};
27use project::Project;
28use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
29use prompt_store::PromptBuilder;
30use proto::Plan;
31use schemars::JsonSchema;
32use serde::{Deserialize, Serialize};
33use settings::Settings;
34use thiserror::Error;
35use util::{ResultExt as _, TryFutureExt as _, post_inc};
36use uuid::Uuid;
37
38use crate::context::{AssistantContext, ContextId, format_context_as_string};
39use crate::thread_store::{
40 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
41 SerializedToolUse, SharedProjectContext,
42};
43use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
44
/// Unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
49
50impl ThreadId {
51 pub fn new() -> Self {
52 Self(Uuid::new_v4().to_string().into())
53 }
54}
55
56impl std::fmt::Display for ThreadId {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 write!(f, "{}", self.0)
59 }
60}
61
62impl From<&str> for ThreadId {
63 fn from(value: &str) -> Self {
64 Self(value.into())
65 }
66}
67
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// Backed by a UUID string; a new one is minted per submission via [`Thread::advance_prompt_id`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
73
74impl PromptId {
75 pub fn new() -> Self {
76 Self(Uuid::new_v4().to_string().into())
77 }
78}
79
80impl std::fmt::Display for PromptId {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 write!(f, "{}", self.0)
83 }
84}
85
/// Monotonically increasing identifier of a [`Message`] within a [`Thread`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
88
89impl MessageId {
90 fn post_inc(&mut self) -> Self {
91 Self(post_inc(&mut self.0))
92 }
93}
94
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (plain text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Attached context, already rendered to text for the model request.
    pub context: String,
    /// Images collected from the message's attached image context.
    pub images: Vec<LanguageModelImage>,
}
104
105impl Message {
106 /// Returns whether the message contains any meaningful text that should be displayed
107 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
108 pub fn should_display_content(&self) -> bool {
109 self.segments.iter().all(|segment| segment.should_display())
110 }
111
112 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
113 if let Some(MessageSegment::Thinking {
114 text: segment,
115 signature: current_signature,
116 }) = self.segments.last_mut()
117 {
118 if let Some(signature) = signature {
119 *current_signature = Some(signature);
120 }
121 segment.push_str(text);
122 } else {
123 self.segments.push(MessageSegment::Thinking {
124 text: text.to_string(),
125 signature,
126 });
127 }
128 }
129
130 pub fn push_text(&mut self, text: &str) {
131 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
132 segment.push_str(text);
133 } else {
134 self.segments.push(MessageSegment::Text(text.to_string()));
135 }
136 }
137
138 pub fn to_string(&self) -> String {
139 let mut result = String::new();
140
141 if !self.context.is_empty() {
142 result.push_str(&self.context);
143 }
144
145 for segment in &self.segments {
146 match segment {
147 MessageSegment::Text(text) => result.push_str(text),
148 MessageSegment::Thinking { text, .. } => {
149 result.push_str("<think>\n");
150 result.push_str(text);
151 result.push_str("\n</think>");
152 }
153 MessageSegment::RedactedThinking(_) => {}
154 }
155 }
156
157 result
158 }
159}
160
161#[derive(Debug, Clone, PartialEq, Eq)]
162pub enum MessageSegment {
163 Text(String),
164 Thinking {
165 text: String,
166 signature: Option<String>,
167 },
168 RedactedThinking(Vec<u8>),
169}
170
171impl MessageSegment {
172 pub fn should_display(&self) -> bool {
173 match self {
174 Self::Text(text) => text.is_empty(),
175 Self::Thinking { text, .. } => text.is_empty(),
176 Self::RedactedThinking(_) => false,
177 }
178 }
179}
180
/// Snapshot of project state (worktrees, unsaved buffers, capture time),
/// serialized with the thread for storage or telemetry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}
187
/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}
193
/// Git repository state captured in a [`WorktreeSnapshot`].
/// All fields are optional because any of them may be unavailable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
201
/// A git checkpoint associated with the message it was captured for,
/// allowing the project to be restored to that point.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
207
/// Binary feedback rating attached to a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
213
/// State of the most recent checkpoint-restore operation: either still
/// in flight, or failed with an error message. Cleared on success.
pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}
223
224impl LastRestoreCheckpoint {
225 pub fn message_id(&self) -> MessageId {
226 match self {
227 LastRestoreCheckpoint::Pending { message_id } => *message_id,
228 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
229 }
230 }
231}
232
/// Progress of the detailed (long-form) thread summary generation.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        /// Last message that was included when the summary was generated.
        message_id: MessageId,
    },
}
245
246#[derive(Default)]
247pub struct TotalTokenUsage {
248 pub total: usize,
249 pub max: usize,
250}
251
252impl TotalTokenUsage {
253 pub fn ratio(&self) -> TokenUsageRatio {
254 #[cfg(debug_assertions)]
255 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
256 .unwrap_or("0.8".to_string())
257 .parse()
258 .unwrap();
259 #[cfg(not(debug_assertions))]
260 let warning_threshold: f32 = 0.8;
261
262 if self.total >= self.max {
263 TokenUsageRatio::Exceeded
264 } else if self.total as f32 / self.max as f32 >= warning_threshold {
265 TokenUsageRatio::Warning
266 } else {
267 TokenUsageRatio::Normal
268 }
269 }
270
271 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
272 TotalTokenUsage {
273 total: self.total + tokens,
274 max: self.max,
275 }
276 }
277}
278
/// Severity bucket produced by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
286
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // `None` until a summary has been generated; `set_summary` refuses to
    // overwrite until then.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    // Every piece of context ever attached to this thread, keyed by ID.
    context: BTreeMap<ContextId, AssistantContext>,
    // Which context IDs were attached alongside which message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    // State of the most recent checkpoint restore, if pending or failed.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint captured with the latest user message; kept or discarded by
    // `finalize_pending_checkpoint` once generation finishes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Optional observer invoked with each request and its streamed events
    // (installed via `set_request_callback`).
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Number of model turns still allowed; decremented by `send_to_model`.
    remaining_turns: u32,
}
322
/// Recorded when the last request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
330
331impl Thread {
    /// Creates a new, empty thread for the given project, with a fresh
    /// thread/prompt ID and no turn limit.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project state once, up front; the shared future
                // lets later consumers (e.g. `serialize`) await the same
                // snapshot.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
377
    /// Reconstructs a thread from its serialized form.
    ///
    /// Per-message images are not persisted and come back empty; a fresh
    /// prompt ID is minted since prompt IDs are per-submission.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    context: message.context,
                    images: Vec::new(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
452
453 pub fn set_request_callback(
454 &mut self,
455 callback: impl 'static
456 + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
457 ) {
458 self.request_callback = Some(Box::new(callback));
459 }
460
    /// This thread's stable identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Returns whether the thread contains no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Mints a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if one hasn't been produced yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
488
489 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
490
491 pub fn summary_or_default(&self) -> SharedString {
492 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
493 }
494
495 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
496 let Some(current_summary) = &self.summary else {
497 // Don't allow setting summary until generated
498 return;
499 };
500
501 let mut new_summary = new_summary.into();
502
503 if new_summary.is_empty() {
504 new_summary = Self::DEFAULT_SUMMARY;
505 }
506
507 if current_summary != &new_summary {
508 self.summary = Some(new_summary);
509 cx.emit(ThreadEvent::SummaryChanged);
510 }
511 }
512
513 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
514 self.latest_detailed_summary()
515 .unwrap_or_else(|| self.text().into())
516 }
517
518 fn latest_detailed_summary(&self) -> Option<SharedString> {
519 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
520 Some(text.clone())
521 } else {
522 None
523 }
524 }
525
    /// Looks up a message by ID.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    /// Iterates over all messages in insertion order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is streaming or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by its tool-use ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending (including errored ones).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
563
    /// Restores the project to the given checkpoint's git state, truncating
    /// the thread back to the checkpoint's message on success. Failures are
    /// surfaced via [`LastRestoreCheckpoint::Error`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in flight so observers can reflect it.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failure visible; the thread is left untruncated.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way the pending checkpoint no longer applies.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
599
    /// Decides the fate of the checkpoint captured with the latest user
    /// message: once generation has finished, keep it only if the repository
    /// actually changed during the turn (or if the comparison itself fails).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still streaming or running tools; caller will retry later.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // Treat a failed comparison as "changed" so we keep the
                    // checkpoint rather than silently dropping it.
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint — err on the side of keeping
            // the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
638
    /// Records `checkpoint` under its message ID and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint restore, if pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
649
650 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
651 let Some(message_ix) = self
652 .messages
653 .iter()
654 .rposition(|message| message.id == message_id)
655 else {
656 return;
657 };
658 for deleted_message in self.messages.drain(message_ix..) {
659 self.context_by_message.remove(&deleted_message.id);
660 self.checkpoints_by_message.remove(&deleted_message.id);
661 }
662 cx.notify();
663 }
664
665 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
666 self.context_by_message
667 .get(&id)
668 .into_iter()
669 .flat_map(|context| {
670 context
671 .iter()
672 .filter_map(|context_id| self.context.get(&context_id))
673 })
674 }
675
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// All tool uses issued by the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// All tool results attached to the given message.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    /// Looks up the result of a tool use by its ID.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a tool use, if it has a result.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool result, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// Whether the given message has any tool results attached.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }

    /// Filter out contexts that have already been included in previous messages
    pub fn filter_new_context<'a>(
        &self,
        context: impl Iterator<Item = &'a AssistantContext>,
    ) -> impl Iterator<Item = &'a AssistantContext> {
        context.filter(|ctx| self.is_context_new(ctx))
    }

    /// Whether this context has not yet been attached to the thread.
    fn is_context_new(&self, context: &AssistantContext) -> bool {
        !self.context.contains_key(&context.id())
    }
721
722 pub fn insert_user_message(
723 &mut self,
724 text: impl Into<String>,
725 context: Vec<AssistantContext>,
726 git_checkpoint: Option<GitStoreCheckpoint>,
727 cx: &mut Context<Self>,
728 ) -> MessageId {
729 let text = text.into();
730
731 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
732
733 let new_context: Vec<_> = context
734 .into_iter()
735 .filter(|ctx| self.is_context_new(ctx))
736 .collect();
737
738 if !new_context.is_empty() {
739 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
740 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
741 message.context = context_string;
742 }
743 }
744
745 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
746 message.images = new_context
747 .iter()
748 .filter_map(|context| {
749 if let AssistantContext::Image(image_context) = context {
750 image_context.image_task.clone().now_or_never().flatten()
751 } else {
752 None
753 }
754 })
755 .collect::<Vec<_>>();
756 }
757
758 self.action_log.update(cx, |log, cx| {
759 // Track all buffers added as context
760 for ctx in &new_context {
761 match ctx {
762 AssistantContext::File(file_ctx) => {
763 log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
764 }
765 AssistantContext::Directory(dir_ctx) => {
766 for context_buffer in &dir_ctx.context_buffers {
767 log.track_buffer(context_buffer.buffer.clone(), cx);
768 }
769 }
770 AssistantContext::Symbol(symbol_ctx) => {
771 log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
772 }
773 AssistantContext::Selection(selection_context) => {
774 log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
775 }
776 AssistantContext::FetchedUrl(_)
777 | AssistantContext::Thread(_)
778 | AssistantContext::Rules(_)
779 | AssistantContext::Image(_) => {}
780 }
781 }
782 });
783 }
784
785 let context_ids = new_context
786 .iter()
787 .map(|context| context.id())
788 .collect::<Vec<_>>();
789 self.context.extend(
790 new_context
791 .into_iter()
792 .map(|context| (context.id(), context)),
793 );
794 self.context_by_message.insert(message_id, context_ids);
795
796 if let Some(git_checkpoint) = git_checkpoint {
797 self.pending_checkpoint = Some(ThreadCheckpoint {
798 message_id,
799 git_checkpoint,
800 });
801 }
802
803 self.auto_capture_telemetry(cx);
804
805 message_id
806 }
807
808 pub fn insert_message(
809 &mut self,
810 role: Role,
811 segments: Vec<MessageSegment>,
812 cx: &mut Context<Self>,
813 ) -> MessageId {
814 let id = self.next_message_id.post_inc();
815 self.messages.push(Message {
816 id,
817 role,
818 segments,
819 context: String::new(),
820 images: Vec::new(),
821 });
822 self.touch_updated_at();
823 cx.emit(ThreadEvent::MessageAdded(id));
824 id
825 }
826
827 pub fn edit_message(
828 &mut self,
829 id: MessageId,
830 new_role: Role,
831 new_segments: Vec<MessageSegment>,
832 cx: &mut Context<Self>,
833 ) -> bool {
834 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
835 return false;
836 };
837 message.role = new_role;
838 message.segments = new_segments;
839 self.touch_updated_at();
840 cx.emit(ThreadEvent::MessageEdited(id));
841 true
842 }
843
844 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
845 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
846 return false;
847 };
848 self.messages.remove(index);
849 self.context_by_message.remove(&id);
850 self.touch_updated_at();
851 cx.emit(ThreadEvent::MessageDeleted(id));
852 true
853 }
854
855 /// Returns the representation of this [`Thread`] in a textual form.
856 ///
857 /// This is the representation we use when attaching a thread as context to another thread.
858 pub fn text(&self) -> String {
859 let mut text = String::new();
860
861 for message in &self.messages {
862 text.push_str(match message.role {
863 language_model::Role::User => "User:",
864 language_model::Role::Assistant => "Assistant:",
865 language_model::Role::System => "System:",
866 });
867 text.push('\n');
868
869 for segment in &message.segments {
870 match segment {
871 MessageSegment::Text(content) => text.push_str(content),
872 MessageSegment::Thinking { text: content, .. } => {
873 text.push_str(&format!("<think>{}</think>", content))
874 }
875 MessageSegment::RedactedThinking(_) => {}
876 }
877 }
878 text.push('\n');
879 }
880
881 text
882 }
883
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the initial project snapshot (captured at construction) before
    /// reading the rest of the thread state.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses and results are persisted per message so
                        // `deserialize` can rebuild the `ToolUseState`.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
947
    /// Number of model turns this thread may still take (initialized to
    /// `u32::MAX`, i.e. effectively unlimited).
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    /// Sets the number of model turns this thread may still take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
955
956 pub fn send_to_model(
957 &mut self,
958 model: Arc<dyn LanguageModel>,
959 window: Option<AnyWindowHandle>,
960 cx: &mut Context<Self>,
961 ) {
962 if self.remaining_turns == 0 {
963 return;
964 }
965
966 self.remaining_turns -= 1;
967
968 let mut request = self.to_completion_request(cx);
969 if model.supports_tools() {
970 request.tools = {
971 let mut tools = Vec::new();
972 tools.extend(
973 self.tools()
974 .read(cx)
975 .enabled_tools(cx)
976 .into_iter()
977 .filter_map(|tool| {
978 // Skip tools that cannot be supported
979 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
980 Some(LanguageModelRequestTool {
981 name: tool.name(),
982 description: tool.description(),
983 input_schema,
984 })
985 }),
986 );
987
988 tools
989 };
990 }
991
992 self.stream_completion(request, model, window, cx);
993 }
994
995 pub fn used_tools_since_last_user_message(&self) -> bool {
996 for message in self.messages.iter().rev() {
997 if self.tool_use.message_has_tool_results(message.id) {
998 return true;
999 } else if message.role == Role::User {
1000 return false;
1001 }
1002 }
1003
1004 false
1005 }
1006
    /// Builds the completion request for the thread's current state: system
    /// prompt, then each message (tool results, context, images, segments,
    /// tool uses), then any stale-file notice.
    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // System prompt comes first; failures are surfaced to the user but
        // the request is still built without it.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Tool results must precede the rest of the message content.
            self.tool_use
                .attach_tool_results(message.id, &mut request_message);

            if !message.context.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.context.to_string()));
            }

            if !message.images.is_empty() {
                // Some providers only support image parts after an initial text part
                if request_message.content.is_empty() {
                    request_message
                        .content
                        .push(MessageContent::Text("Images attached by user:".to_string()));
                }

                for image in &message.images {
                    request_message
                        .content
                        .push(MessageContent::Image(image.clone()))
                }
            }

            // Empty segments are skipped so we never send empty content parts.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1118
    /// Builds a stripped-down request for summarization: plain text segments
    /// only (no tools, thinking, context, or IDs), with `added_user_message`
    /// appended as the final user turn.
    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Skip tool results during summarization.
            if self.tool_use.message_has_tool_results(message.id) {
                continue;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => request_message
                        .content
                        .push(MessageContent::Text(text.clone())),
                    MessageSegment::Thinking { .. } => {}
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            // Don't send content-less messages.
            if request_message.content.is_empty() {
                continue;
            }

            request.messages.push(request_message);
        }

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text(added_user_message)],
            cache: false,
        });

        request
    }
1166
1167 fn attached_tracked_files_state(
1168 &self,
1169 messages: &mut Vec<LanguageModelRequestMessage>,
1170 cx: &App,
1171 ) {
1172 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1173
1174 let mut stale_message = String::new();
1175
1176 let action_log = self.action_log.read(cx);
1177
1178 for stale_file in action_log.stale_buffers(cx) {
1179 let Some(file) = stale_file.read(cx).file() else {
1180 continue;
1181 };
1182
1183 if stale_message.is_empty() {
1184 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1185 }
1186
1187 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1188 }
1189
1190 let mut content = Vec::with_capacity(2);
1191
1192 if !stale_message.is_empty() {
1193 content.push(stale_message.into());
1194 }
1195
1196 if !content.is_empty() {
1197 let context_message = LanguageModelRequestMessage {
1198 role: Role::User,
1199 content,
1200 cache: false,
1201 };
1202
1203 messages.push(context_message);
1204 }
1205 }
1206
    /// Streams a completion from `model` for `request`, mutating this thread
    /// as events arrive. The stream runs in a spawned task registered in
    /// `pending_completions`; on success with a `ToolUse` stop reason, the
    /// pending tools are kicked off, and on error the user is notified and
    /// the completion is canceled.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only capture the request and response events when a callback is
        // installed (used for diagnostics/inspection).
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot usage before streaming so the completion telemetry can
            // report the delta for just this completion.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                while let Some(event) = events.next().await {
                    // Record every event (including errors, stringified) for
                    // the request callback, before the `?` below can bail.
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates appear to be cumulative within
                                // a request, so only the delta since the last
                                // update is added to the thread-wide total.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking {
                                                text: chunk.to_string(),
                                                signature,
                                            }],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to the most recent
                                // assistant message, creating one if needed.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Partial (still-streaming) input is re-emitted
                                // below so the UI can render it incrementally.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    // Yield between events so the UI stays responsive.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known error types to dedicated UI errors;
                            // anything else becomes a generic message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        // Remembered so `total_token_usage` can
                                        // report the overflow for this model.
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report only the usage attributable to this completion.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1477
1478 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1479 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1480 return;
1481 };
1482
1483 if !model.provider.is_authenticated(cx) {
1484 return;
1485 }
1486
1487 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1488 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1489 If the conversation is about a specific subject, include it in the title. \
1490 Be descriptive. DO NOT speak in the first person.";
1491
1492 let request = self.to_summarize_request(added_user_message.into());
1493
1494 self.pending_summary = cx.spawn(async move |this, cx| {
1495 async move {
1496 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1497 let (mut messages, usage) = stream.await?;
1498
1499 if let Some(usage) = usage {
1500 this.update(cx, |_thread, cx| {
1501 cx.emit(ThreadEvent::UsageUpdated(usage));
1502 })
1503 .ok();
1504 }
1505
1506 let mut new_summary = String::new();
1507 while let Some(message) = messages.stream.next().await {
1508 let text = message?;
1509 let mut lines = text.lines();
1510 new_summary.extend(lines.next());
1511
1512 // Stop if the LLM generated multiple lines.
1513 if lines.next().is_some() {
1514 break;
1515 }
1516 }
1517
1518 this.update(cx, |this, cx| {
1519 if !new_summary.is_empty() {
1520 this.summary = Some(new_summary.into());
1521 }
1522
1523 cx.emit(ThreadEvent::SummaryGenerated);
1524 })?;
1525
1526 anyhow::Ok(())
1527 }
1528 .log_err()
1529 .await
1530 });
1531 }
1532
1533 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1534 let last_message_id = self.messages.last().map(|message| message.id)?;
1535
1536 match &self.detailed_summary_state {
1537 DetailedSummaryState::Generating { message_id, .. }
1538 | DetailedSummaryState::Generated { message_id, .. }
1539 if *message_id == last_message_id =>
1540 {
1541 // Already up-to-date
1542 return None;
1543 }
1544 _ => {}
1545 }
1546
1547 let ConfiguredModel { model, provider } =
1548 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1549
1550 if !provider.is_authenticated(cx) {
1551 return None;
1552 }
1553
1554 let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1555 1. A brief overview of what was discussed\n\
1556 2. Key facts or information discovered\n\
1557 3. Outcomes or conclusions reached\n\
1558 4. Any action items or next steps if any\n\
1559 Format it in Markdown with headings and bullet points.";
1560
1561 let request = self.to_summarize_request(added_user_message.into());
1562
1563 let task = cx.spawn(async move |thread, cx| {
1564 let stream = model.stream_completion_text(request, &cx);
1565 let Some(mut messages) = stream.await.log_err() else {
1566 thread
1567 .update(cx, |this, _cx| {
1568 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1569 })
1570 .log_err();
1571
1572 return;
1573 };
1574
1575 let mut new_detailed_summary = String::new();
1576
1577 while let Some(chunk) = messages.stream.next().await {
1578 if let Some(chunk) = chunk.log_err() {
1579 new_detailed_summary.push_str(&chunk);
1580 }
1581 }
1582
1583 thread
1584 .update(cx, |this, _cx| {
1585 this.detailed_summary_state = DetailedSummaryState::Generated {
1586 text: new_detailed_summary.into(),
1587 message_id: last_message_id,
1588 };
1589 })
1590 .log_err();
1591 });
1592
1593 self.detailed_summary_state = DetailedSummaryState::Generating {
1594 message_id: last_message_id,
1595 };
1596
1597 Some(task)
1598 }
1599
    /// Whether a detailed summary is currently being generated for this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1606
1607 pub fn use_pending_tools(
1608 &mut self,
1609 window: Option<AnyWindowHandle>,
1610 cx: &mut Context<Self>,
1611 ) -> Vec<PendingToolUse> {
1612 self.auto_capture_telemetry(cx);
1613 let request = self.to_completion_request(cx);
1614 let messages = Arc::new(request.messages);
1615 let pending_tool_uses = self
1616 .tool_use
1617 .pending_tool_uses()
1618 .into_iter()
1619 .filter(|tool_use| tool_use.status.is_idle())
1620 .cloned()
1621 .collect::<Vec<_>>();
1622
1623 for tool_use in pending_tool_uses.iter() {
1624 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1625 if tool.needs_confirmation(&tool_use.input, cx)
1626 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1627 {
1628 self.tool_use.confirm_tool_use(
1629 tool_use.id.clone(),
1630 tool_use.ui_text.clone(),
1631 tool_use.input.clone(),
1632 messages.clone(),
1633 tool,
1634 );
1635 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1636 } else {
1637 self.run_tool(
1638 tool_use.id.clone(),
1639 tool_use.ui_text.clone(),
1640 tool_use.input.clone(),
1641 &messages,
1642 tool,
1643 window,
1644 cx,
1645 );
1646 }
1647 }
1648 }
1649
1650 pending_tool_uses
1651 }
1652
1653 pub fn run_tool(
1654 &mut self,
1655 tool_use_id: LanguageModelToolUseId,
1656 ui_text: impl Into<SharedString>,
1657 input: serde_json::Value,
1658 messages: &[LanguageModelRequestMessage],
1659 tool: Arc<dyn Tool>,
1660 window: Option<AnyWindowHandle>,
1661 cx: &mut Context<Thread>,
1662 ) {
1663 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1664 self.tool_use
1665 .run_pending_tool(tool_use_id, ui_text.into(), task);
1666 }
1667
    /// Runs `tool` with `input` and returns a task that records the tool's
    /// output (via `insert_tool_output`) and runs completion bookkeeping
    /// (`tool_finished`) once the tool resolves. Disabled tools resolve
    /// immediately to an error result instead of running.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // If the thread was dropped while the tool ran, the output is
                // simply discarded.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            cx,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
1716
    /// Bookkeeping after a single tool resolves: once *all* tools have
    /// finished, their results are attached to a fresh user message and —
    /// unless `canceled` — the conversation is sent back to the default
    /// model. Always emits [`ThreadEvent::ToolFinished`].
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            let model_registry = LanguageModelRegistry::read_global(cx);
            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
                // Results are attached even when canceled, so the transcript
                // stays consistent; only the follow-up request is skipped.
                self.attach_tool_results(cx);
                if !canceled {
                    self.send_to_model(model, window, cx);
                }
            }
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }
1740
    /// Insert an empty user message to be populated with tool results upon send.
    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Tool results are assumed to be waiting on the next message id, so they will populate
        // this empty message before sending to model. Would prefer this to be more straightforward.
        self.insert_message(Role::User, vec![], cx);
        self.auto_capture_telemetry(cx);
    }
1748
1749 /// Cancels the last pending completion, if there are any pending.
1750 ///
1751 /// Returns whether a completion was canceled.
1752 pub fn cancel_last_completion(
1753 &mut self,
1754 window: Option<AnyWindowHandle>,
1755 cx: &mut Context<Self>,
1756 ) -> bool {
1757 let canceled = if self.pending_completions.pop().is_some() {
1758 true
1759 } else {
1760 let mut canceled = false;
1761 for pending_tool_use in self.tool_use.cancel_pending() {
1762 canceled = true;
1763 self.tool_finished(
1764 pending_tool_use.id.clone(),
1765 Some(pending_tool_use),
1766 true,
1767 window,
1768 cx,
1769 );
1770 }
1771 canceled
1772 };
1773 self.finalize_pending_checkpoint(cx);
1774 canceled
1775 }
1776
    /// The thread-level feedback rating, if one has been recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1780
    /// The feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1784
1785 pub fn report_message_feedback(
1786 &mut self,
1787 message_id: MessageId,
1788 feedback: ThreadFeedback,
1789 cx: &mut Context<Self>,
1790 ) -> Task<Result<()>> {
1791 if self.message_feedback.get(&message_id) == Some(&feedback) {
1792 return Task::ready(Ok(()));
1793 }
1794
1795 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1796 let serialized_thread = self.serialize(cx);
1797 let thread_id = self.id().clone();
1798 let client = self.project.read(cx).client();
1799
1800 let enabled_tool_names: Vec<String> = self
1801 .tools()
1802 .read(cx)
1803 .enabled_tools(cx)
1804 .iter()
1805 .map(|tool| tool.name().to_string())
1806 .collect();
1807
1808 self.message_feedback.insert(message_id, feedback);
1809
1810 cx.notify();
1811
1812 let message_content = self
1813 .message(message_id)
1814 .map(|msg| msg.to_string())
1815 .unwrap_or_default();
1816
1817 cx.background_spawn(async move {
1818 let final_project_snapshot = final_project_snapshot.await;
1819 let serialized_thread = serialized_thread.await?;
1820 let thread_data =
1821 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1822
1823 let rating = match feedback {
1824 ThreadFeedback::Positive => "positive",
1825 ThreadFeedback::Negative => "negative",
1826 };
1827 telemetry::event!(
1828 "Assistant Thread Rated",
1829 rating,
1830 thread_id,
1831 enabled_tool_names,
1832 message_id = message_id.0,
1833 message_content,
1834 thread_data,
1835 final_project_snapshot
1836 );
1837 client.telemetry().flush_events().await;
1838
1839 Ok(())
1840 })
1841 }
1842
1843 pub fn report_feedback(
1844 &mut self,
1845 feedback: ThreadFeedback,
1846 cx: &mut Context<Self>,
1847 ) -> Task<Result<()>> {
1848 let last_assistant_message_id = self
1849 .messages
1850 .iter()
1851 .rev()
1852 .find(|msg| msg.role == Role::Assistant)
1853 .map(|msg| msg.id);
1854
1855 if let Some(message_id) = last_assistant_message_id {
1856 self.report_message_feedback(message_id, feedback, cx)
1857 } else {
1858 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1859 let serialized_thread = self.serialize(cx);
1860 let thread_id = self.id().clone();
1861 let client = self.project.read(cx).client();
1862 self.feedback = Some(feedback);
1863 cx.notify();
1864
1865 cx.background_spawn(async move {
1866 let final_project_snapshot = final_project_snapshot.await;
1867 let serialized_thread = serialized_thread.await?;
1868 let thread_data = serde_json::to_value(serialized_thread)
1869 .unwrap_or_else(|_| serde_json::Value::Null);
1870
1871 let rating = match feedback {
1872 ThreadFeedback::Positive => "positive",
1873 ThreadFeedback::Negative => "negative",
1874 };
1875 telemetry::event!(
1876 "Assistant Thread Rated",
1877 rating,
1878 thread_id,
1879 thread_data,
1880 final_project_snapshot
1881 );
1882 client.telemetry().flush_events().await;
1883
1884 Ok(())
1885 })
1886 }
1887 }
1888
1889 /// Create a snapshot of the current project state including git information and unsaved buffers.
1890 fn project_snapshot(
1891 project: Entity<Project>,
1892 cx: &mut Context<Self>,
1893 ) -> Task<Arc<ProjectSnapshot>> {
1894 let git_store = project.read(cx).git_store().clone();
1895 let worktree_snapshots: Vec<_> = project
1896 .read(cx)
1897 .visible_worktrees(cx)
1898 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1899 .collect();
1900
1901 cx.spawn(async move |_, cx| {
1902 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1903
1904 let mut unsaved_buffers = Vec::new();
1905 cx.update(|app_cx| {
1906 let buffer_store = project.read(app_cx).buffer_store();
1907 for buffer_handle in buffer_store.read(app_cx).buffers() {
1908 let buffer = buffer_handle.read(app_cx);
1909 if buffer.is_dirty() {
1910 if let Some(file) = buffer.file() {
1911 let path = file.path().to_string_lossy().to_string();
1912 unsaved_buffers.push(path);
1913 }
1914 }
1915 }
1916 })
1917 .ok();
1918
1919 Arc::new(ProjectSnapshot {
1920 worktree_snapshots,
1921 unsaved_buffer_paths: unsaved_buffers,
1922 timestamp: Utc::now(),
1923 })
1924 })
1925 }
1926
1927 fn worktree_snapshot(
1928 worktree: Entity<project::Worktree>,
1929 git_store: Entity<GitStore>,
1930 cx: &App,
1931 ) -> Task<WorktreeSnapshot> {
1932 cx.spawn(async move |cx| {
1933 // Get worktree path and snapshot
1934 let worktree_info = cx.update(|app_cx| {
1935 let worktree = worktree.read(app_cx);
1936 let path = worktree.abs_path().to_string_lossy().to_string();
1937 let snapshot = worktree.snapshot();
1938 (path, snapshot)
1939 });
1940
1941 let Ok((worktree_path, _snapshot)) = worktree_info else {
1942 return WorktreeSnapshot {
1943 worktree_path: String::new(),
1944 git_state: None,
1945 };
1946 };
1947
1948 let git_state = git_store
1949 .update(cx, |git_store, cx| {
1950 git_store
1951 .repositories()
1952 .values()
1953 .find(|repo| {
1954 repo.read(cx)
1955 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1956 .is_some()
1957 })
1958 .cloned()
1959 })
1960 .ok()
1961 .flatten()
1962 .map(|repo| {
1963 repo.update(cx, |repo, _| {
1964 let current_branch =
1965 repo.branch.as_ref().map(|branch| branch.name.to_string());
1966 repo.send_job(None, |state, _| async move {
1967 let RepositoryState::Local { backend, .. } = state else {
1968 return GitState {
1969 remote_url: None,
1970 head_sha: None,
1971 current_branch,
1972 diff: None,
1973 };
1974 };
1975
1976 let remote_url = backend.remote_url("origin");
1977 let head_sha = backend.head_sha();
1978 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1979
1980 GitState {
1981 remote_url,
1982 head_sha,
1983 current_branch,
1984 diff,
1985 }
1986 })
1987 })
1988 });
1989
1990 let git_state = match git_state {
1991 Some(git_state) => match git_state.ok() {
1992 Some(git_state) => git_state.await.ok(),
1993 None => None,
1994 },
1995 None => None,
1996 };
1997
1998 WorktreeSnapshot {
1999 worktree_path,
2000 git_state,
2001 }
2002 })
2003 }
2004
2005 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2006 let mut markdown = Vec::new();
2007
2008 if let Some(summary) = self.summary() {
2009 writeln!(markdown, "# {summary}\n")?;
2010 };
2011
2012 for message in self.messages() {
2013 writeln!(
2014 markdown,
2015 "## {role}\n",
2016 role = match message.role {
2017 Role::User => "User",
2018 Role::Assistant => "Assistant",
2019 Role::System => "System",
2020 }
2021 )?;
2022
2023 if !message.context.is_empty() {
2024 writeln!(markdown, "{}", message.context)?;
2025 }
2026
2027 for segment in &message.segments {
2028 match segment {
2029 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2030 MessageSegment::Thinking { text, .. } => {
2031 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2032 }
2033 MessageSegment::RedactedThinking(_) => {}
2034 }
2035 }
2036
2037 for tool_use in self.tool_uses_for_message(message.id, cx) {
2038 writeln!(
2039 markdown,
2040 "**Use Tool: {} ({})**",
2041 tool_use.name, tool_use.id
2042 )?;
2043 writeln!(markdown, "```json")?;
2044 writeln!(
2045 markdown,
2046 "{}",
2047 serde_json::to_string_pretty(&tool_use.input)?
2048 )?;
2049 writeln!(markdown, "```")?;
2050 }
2051
2052 for tool_result in self.tool_results_for_message(message.id) {
2053 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
2054 if tool_result.is_error {
2055 write!(markdown, " (Error)")?;
2056 }
2057
2058 writeln!(markdown, "**\n")?;
2059 writeln!(markdown, "{}", tool_result.content)?;
2060 }
2061 }
2062
2063 Ok(String::from_utf8_lossy(&markdown).to_string())
2064 }
2065
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept
    /// (accepted), delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2076
    /// Marks all of the agent's edits as kept (accepted) in the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2081
    /// Reverts the agent's edits within `buffer_ranges` of `buffer`,
    /// delegating to the action log; the returned task resolves when the
    /// rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2092
    /// The action log tracking the agent's buffer edits and reads.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2096
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2100
2101 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2102 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2103 return;
2104 }
2105
2106 let now = Instant::now();
2107 if let Some(last) = self.last_auto_capture_at {
2108 if now.duration_since(last).as_secs() < 10 {
2109 return;
2110 }
2111 }
2112
2113 self.last_auto_capture_at = Some(now);
2114
2115 let thread_id = self.id().clone();
2116 let github_login = self
2117 .project
2118 .read(cx)
2119 .user_store()
2120 .read(cx)
2121 .current_user()
2122 .map(|user| user.github_login.clone());
2123 let client = self.project.read(cx).client().clone();
2124 let serialize_task = self.serialize(cx);
2125
2126 cx.background_executor()
2127 .spawn(async move {
2128 if let Ok(serialized_thread) = serialize_task.await {
2129 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2130 telemetry::event!(
2131 "Agent Thread Auto-Captured",
2132 thread_id = thread_id.to_string(),
2133 thread_data = thread_data,
2134 auto_capture_reason = "tracked_user",
2135 github_login = github_login
2136 );
2137
2138 client.telemetry().flush_events().await;
2139 }
2140 }
2141 })
2142 .detach();
2143 }
2144
    /// Total token usage accumulated across every completion in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2148
2149 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2150 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2151 return TotalTokenUsage::default();
2152 };
2153
2154 let max = model.model.max_token_count();
2155
2156 let index = self
2157 .messages
2158 .iter()
2159 .position(|msg| msg.id == message_id)
2160 .unwrap_or(0);
2161
2162 if index == 0 {
2163 return TotalTokenUsage { total: 0, max };
2164 }
2165
2166 let token_usage = &self
2167 .request_token_usage
2168 .get(index - 1)
2169 .cloned()
2170 .unwrap_or_default();
2171
2172 TotalTokenUsage {
2173 total: token_usage.total_tokens() as usize,
2174 max,
2175 }
2176 }
2177
2178 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2179 let model_registry = LanguageModelRegistry::read_global(cx);
2180 let Some(model) = model_registry.default_model() else {
2181 return TotalTokenUsage::default();
2182 };
2183
2184 let max = model.model.max_token_count();
2185
2186 if let Some(exceeded_error) = &self.exceeded_window_error {
2187 if model.model.id() == exceeded_error.model_id {
2188 return TotalTokenUsage {
2189 total: exceeded_error.token_count,
2190 max,
2191 };
2192 }
2193 }
2194
2195 let total = self
2196 .token_usage_at_last_message()
2197 .unwrap_or_default()
2198 .total_tokens() as usize;
2199
2200 TotalTokenUsage { total, max }
2201 }
2202
2203 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2204 self.request_token_usage
2205 .get(self.messages.len().saturating_sub(1))
2206 .or_else(|| self.request_token_usage.last())
2207 .cloned()
2208 }
2209
    /// Records `token_usage` for the last message, first resizing the
    /// per-message usage vector to one entry per message (padding new slots
    /// with the most recent known usage).
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        // `last_mut` is only `None` when there are no messages at all.
        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2219
2220 pub fn deny_tool_use(
2221 &mut self,
2222 tool_use_id: LanguageModelToolUseId,
2223 tool_name: Arc<str>,
2224 window: Option<AnyWindowHandle>,
2225 cx: &mut Context<Self>,
2226 ) {
2227 let err = Err(anyhow::anyhow!(
2228 "Permission to run tool action denied by user"
2229 ));
2230
2231 self.tool_use
2232 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2233 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2234 }
2235}
2236
/// User-facing errors surfaced while interacting with a language model.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider requires payment before accepting more requests.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured monthly spend cap has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request quota for the user's plan has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, displayed with a header and detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2251
/// Events emitted by a `Thread` as completions stream in and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    /// Updated request-usage information reported by the provider.
    UsageUpdated(RequestUsage),
    /// A completion stream event was processed (emitted once per event).
    StreamedCompletion,
    /// A text chunk arrived from the model.
    ReceivedTextChunk,
    /// Assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use is still streaming in; `input` holds the partial input
    /// received so far.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    // NOTE(review): the message/summary/checkpoint events below are emitted
    // outside this chunk; descriptions inferred from their names — confirm.
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// A (short) thread summary finished generating.
    SummaryGenerated,
    SummaryChanged,
    /// The model stopped to use tools; `tool_uses` are now being handled.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool produced output, or was canceled/denied.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// At least one requested tool needs user confirmation before running.
    ToolConfirmationNeeded,
}
2283
// Marker impl: lets `Thread` entities emit `ThreadEvent`s through gpui's
// event system (`cx.emit(...)` / subscriptions).
impl EventEmitter<ThreadEvent> for Thread {}
2285
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifier for this completion (monotonically assigned elsewhere in
    // this file via `post_inc` — TODO confirm at the construction site).
    id: usize,
    // Held only to keep the streaming task alive for the lifetime of this
    // struct; NOTE(review): gpui tasks are presumably canceled on drop.
    _task: Task<()>,
}
2290
// Unit tests covering how `Thread` attaches context to user messages and
// builds completion requests.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// An attached file should be rendered into the message's `context`
    /// field and prepended to the message text in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Two messages: a leading non-user message at index 0 (presumably the
        // system/project prompt — confirm in `to_completion_request`), then
        // the user message with the attached context prepended.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    /// A context item already attached to an earlier message should not be
    /// repeated in later messages — each message carries only new context.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Four messages: the leading message at index 0 plus the 3 user messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    /// Messages inserted with no context items should have an empty
    /// `context` field and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    /// When a buffer attached as context is modified after being read, the
    /// next completion request should end with a stale-buffer notification.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert at offset range 1..1 (just after the first character) so
            // the buffer content no longer matches what the context captured
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    // Registers the global settings store and every subsystem these tests
    // depend on; must run before any other setup.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    // Builds the full test fixture: a workspace window, a thread store with
    // an empty tool set, a fresh thread, and a context store for the project.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    // Opens `path` in the project and registers the resulting buffer with
    // the context store; returns the buffer so tests can edit it later.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}