1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState, USING_TOOL_MARKER};
42
43#[derive(
44 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
45)]
46pub struct ThreadId(Arc<str>);
47
48impl ThreadId {
49 pub fn new() -> Self {
50 Self(Uuid::new_v4().to_string().into())
51 }
52}
53
54impl std::fmt::Display for ThreadId {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
60impl From<&str> for ThreadId {
61 fn from(value: &str) -> Self {
62 Self(value.into())
63 }
64}
65
66/// The ID of the user prompt that initiated a request.
67///
68/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
70pub struct PromptId(Arc<str>);
71
72impl PromptId {
73 pub fn new() -> Self {
74 Self(Uuid::new_v4().to_string().into())
75 }
76}
77
78impl std::fmt::Display for PromptId {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "{}", self.0)
81 }
82}
83
84#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
85pub struct MessageId(pub(crate) usize);
86
87impl MessageId {
88 fn post_inc(&mut self) -> Self {
89 Self(post_inc(&mut self.0))
90 }
91}
92
93/// A message in a [`Thread`].
94#[derive(Debug, Clone)]
95pub struct Message {
96 pub id: MessageId,
97 pub role: Role,
98 pub segments: Vec<MessageSegment>,
99 pub context: String,
100}
101
102impl Message {
103 /// Returns whether the message contains any meaningful text that should be displayed
104 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
105 pub fn should_display_content(&self) -> bool {
106 self.segments.iter().all(|segment| segment.should_display())
107 }
108
109 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
110 if let Some(MessageSegment::Thinking {
111 text: segment,
112 signature: current_signature,
113 }) = self.segments.last_mut()
114 {
115 if let Some(signature) = signature {
116 *current_signature = Some(signature);
117 }
118 segment.push_str(text);
119 } else {
120 self.segments.push(MessageSegment::Thinking {
121 text: text.to_string(),
122 signature,
123 });
124 }
125 }
126
127 pub fn push_text(&mut self, text: &str) {
128 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
129 segment.push_str(text);
130 } else {
131 self.segments.push(MessageSegment::Text(text.to_string()));
132 }
133 }
134
135 pub fn to_string(&self) -> String {
136 let mut result = String::new();
137
138 if !self.context.is_empty() {
139 result.push_str(&self.context);
140 }
141
142 for segment in &self.segments {
143 match segment {
144 MessageSegment::Text(text) => result.push_str(text),
145 MessageSegment::Thinking { text, .. } => {
146 result.push_str("<think>\n");
147 result.push_str(text);
148 result.push_str("\n</think>");
149 }
150 MessageSegment::RedactedThinking(_) => {}
151 }
152 }
153
154 result
155 }
156}
157
158#[derive(Debug, Clone, PartialEq, Eq)]
159pub enum MessageSegment {
160 Text(String),
161 Thinking {
162 text: String,
163 signature: Option<String>,
164 },
165 RedactedThinking(Vec<u8>),
166}
167
168impl MessageSegment {
169 pub fn should_display(&self) -> bool {
170 // We add USING_TOOL_MARKER when making a request that includes tool uses
171 // without non-whitespace text around them, and this can cause the model
172 // to mimic the pattern, so we consider those segments not displayable.
173 match self {
174 Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
175 Self::Thinking { text, .. } => text.is_empty() || text.trim() == USING_TOOL_MARKER,
176 Self::RedactedThinking(_) => false,
177 }
178 }
179}
180
/// Point-in-time snapshot of the project's worktrees, captured when a thread
/// is created and persisted alongside the serialized thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved edits at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when the worktree is not a git repository — TODO confirm.
    pub git_state: Option<GitState>,
}

/// Captured git repository metadata for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Textual diff of uncommitted changes, when one could be produced.
    pub diff: Option<String>,
}
201
/// Associates a message with the git checkpoint taken when it was inserted,
/// so the project state can later be restored to that point.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Positive or negative user feedback recorded for a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
213
214pub enum LastRestoreCheckpoint {
215 Pending {
216 message_id: MessageId,
217 },
218 Error {
219 message_id: MessageId,
220 error: String,
221 },
222}
223
224impl LastRestoreCheckpoint {
225 pub fn message_id(&self) -> MessageId {
226 match self {
227 LastRestoreCheckpoint::Pending { message_id } => *message_id,
228 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
229 }
230 }
231}
232
/// Lifecycle of the thread's detailed summary: never requested, currently
/// being generated, or generated.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        // Last message included in the in-flight summary request.
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        // Last message covered by the generated summary.
        message_id: MessageId,
    },
}
245
/// Running token usage for a thread measured against a model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens available (the model's context window).
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, near the limit, or exceeded.
    ///
    /// In debug builds the warning threshold may be overridden via the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable. An unset or
    /// unparsable value now falls back to the default of 0.8 — previously a
    /// malformed value would `unwrap()` the parse and panic.
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // NOTE: when `max` is 0 (e.g. `TotalTokenUsage::default()`), the
        // first branch reports `Exceeded`, which also sidesteps the division
        // by zero below.
        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Severity bucket produced by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
286
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    // Last modification time; bumped by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    // `None` until a summary has been generated; see `set_summary`.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    // ID to assign to the next inserted message.
    next_message_id: MessageId,
    // Refreshed by `advance_prompt_id`; stamped onto outgoing requests.
    last_prompt_id: PromptId,
    // Every context item ever attached to this thread, keyed by ID.
    context: BTreeMap<ContextId, AssistantContext>,
    // Which context items were attached to which message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    // Git checkpoints recorded per message, used for restore.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    // Monotonic counter used to identify pending completions.
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken when the latest user message was inserted; finalized
    // or discarded in `finalize_pending_checkpoint`.
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Project snapshot captured in the background at thread creation.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    // Timestamp of the last automatic telemetry capture — presumably used to
    // rate-limit `auto_capture_telemetry`; confirm there.
    last_auto_capture_at: Option<Instant>,
    // Optional hook invoked with each request and its streamed events; set
    // via `set_request_callback`.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
}
321
/// Recorded when a completion request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
329
330impl Thread {
    /// Creates an empty thread for `project`, kicking off a background task
    /// that captures the initial project snapshot.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Start the snapshot immediately; `serialize` awaits this
                // shared task later.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
375
    /// Reconstructs a thread from its serialized form.
    ///
    /// Message IDs continue after the highest persisted ID. Context maps are
    /// not persisted and start out empty; tool-use state is rebuilt from the
    /// serialized messages (the `|_| true` predicate accepts every recorded
    /// tool use).
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            // Map persisted messages/segments back into their runtime types.
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    context: message.context,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            // The snapshot was persisted, so wrap it in an already-completed task.
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
        }
    }
448
    /// Installs a hook invoked with each outgoing request and the completion
    /// events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// The unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the modification timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if none has been produced yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared project context used when building the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Placeholder title used while no summary exists.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The generated summary, falling back to [`Self::DEFAULT_SUMMARY`].
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
490
491 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
492 let Some(current_summary) = &self.summary else {
493 // Don't allow setting summary until generated
494 return;
495 };
496
497 let mut new_summary = new_summary.into();
498
499 if new_summary.is_empty() {
500 new_summary = Self::DEFAULT_SUMMARY;
501 }
502
503 if current_summary != &new_summary {
504 self.summary = Some(new_summary);
505 cx.emit(ThreadEvent::SummaryChanged);
506 }
507 }
508
509 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
510 self.latest_detailed_summary()
511 .unwrap_or_else(|| self.text().into())
512 }
513
514 fn latest_detailed_summary(&self) -> Option<SharedString> {
515 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
516 Some(text.clone())
517 } else {
518 None
519 }
520 }
521
    /// Looks up a message by ID.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }

    /// Iterates over all messages in insertion order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
559
    /// Restores the project to the given checkpoint's git state.
    ///
    /// Progress is tracked in `last_restore_checkpoint`; on success the
    /// thread is truncated back to the checkpoint's message.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in flight so observers can reflect it.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failure around so it can be surfaced for this message.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Success: drop messages after the checkpoint and clear
                    // the pending marker.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
595
    /// Decides what to do with the checkpoint taken at the start of the
    /// current turn once generation has finished.
    ///
    /// If the project is unchanged compared to a freshly-taken checkpoint,
    /// the pending checkpoint is deleted; otherwise it is recorded so the
    /// user can restore it. No-op while generation is still in progress.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // Compare the pre-turn checkpoint with the current state; a
                // comparison failure is treated as "changed".
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    // Nothing changed during the turn; discard the checkpoint.
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The fresh checkpoint was only needed for the comparison.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            // Couldn't take a new checkpoint: keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
646
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint-restore attempt, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
657
658 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
659 let Some(message_ix) = self
660 .messages
661 .iter()
662 .rposition(|message| message.id == message_id)
663 else {
664 return;
665 };
666 for deleted_message in self.messages.drain(message_ix..) {
667 self.context_by_message.remove(&deleted_message.id);
668 self.checkpoints_by_message.remove(&deleted_message.id);
669 }
670 cx.notify();
671 }
672
673 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
674 self.context_by_message
675 .get(&id)
676 .into_iter()
677 .flat_map(|context| {
678 context
679 .iter()
680 .filter_map(|context_id| self.context.get(&context_id))
681 })
682 }
683
684 /// Returns whether all of the tool uses have finished running.
685 pub fn all_tools_finished(&self) -> bool {
686 // If the only pending tool uses left are the ones with errors, then
687 // that means that we've finished running all of the pending tools.
688 self.tool_use
689 .pending_tool_uses()
690 .iter()
691 .all(|tool_use| tool_use.status.is_error())
692 }
693
    /// Tool uses recorded for the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results recorded for the given message.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }

    /// The result of a tool use, if one has been recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a recorded tool result.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool result, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// Whether the given message has any tool results attached.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
717
718 /// Filter out contexts that have already been included in previous messages
719 pub fn filter_new_context<'a>(
720 &self,
721 context: impl Iterator<Item = &'a AssistantContext>,
722 ) -> impl Iterator<Item = &'a AssistantContext> {
723 context.filter(|ctx| self.is_context_new(ctx))
724 }
725
726 fn is_context_new(&self, context: &AssistantContext) -> bool {
727 !self.context.contains_key(&context.id())
728 }
729
    /// Inserts a user message along with its attached context.
    ///
    /// Context already attached earlier in the thread is filtered out, and
    /// buffers backing the new context are registered with the action log.
    /// When `git_checkpoint` is provided it becomes the pending checkpoint
    /// for this message. Returns the new message's ID.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        // Keep only context that hasn't been attached earlier in the thread.
        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
            // Store the rendered context text on the message itself.
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.buffer_added_as_context(
                                symbol_ctx.context_symbol.buffer.clone(),
                                cx,
                            );
                        }
                        AssistantContext::Excerpt(excerpt_context) => {
                            log.buffer_added_as_context(
                                excerpt_context.context_buffer.buffer.clone(),
                                cx,
                            );
                        }
                        // These context kinds carry no buffers to track.
                        AssistantContext::FetchedUrl(_)
                        | AssistantContext::Thread(_)
                        | AssistantContext::Rules(_) => {}
                    }
                }
            });
        }

        // Index the new context globally and per message.
        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
807
808 pub fn insert_message(
809 &mut self,
810 role: Role,
811 segments: Vec<MessageSegment>,
812 cx: &mut Context<Self>,
813 ) -> MessageId {
814 let id = self.next_message_id.post_inc();
815 self.messages.push(Message {
816 id,
817 role,
818 segments,
819 context: String::new(),
820 });
821 self.touch_updated_at();
822 cx.emit(ThreadEvent::MessageAdded(id));
823 id
824 }
825
826 pub fn edit_message(
827 &mut self,
828 id: MessageId,
829 new_role: Role,
830 new_segments: Vec<MessageSegment>,
831 cx: &mut Context<Self>,
832 ) -> bool {
833 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
834 return false;
835 };
836 message.role = new_role;
837 message.segments = new_segments;
838 self.touch_updated_at();
839 cx.emit(ThreadEvent::MessageEdited(id));
840 true
841 }
842
843 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
844 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
845 return false;
846 };
847 self.messages.remove(index);
848 self.context_by_message.remove(&id);
849 self.touch_updated_at();
850 cx.emit(ThreadEvent::MessageDeleted(id));
851 true
852 }
853
854 /// Returns the representation of this [`Thread`] in a textual form.
855 ///
856 /// This is the representation we use when attaching a thread as context to another thread.
857 pub fn text(&self) -> String {
858 let mut text = String::new();
859
860 for message in &self.messages {
861 text.push_str(match message.role {
862 language_model::Role::User => "User:",
863 language_model::Role::Assistant => "Assistant:",
864 language_model::Role::System => "System:",
865 });
866 text.push('\n');
867
868 for segment in &message.segments {
869 match segment {
870 MessageSegment::Text(content) => text.push_str(content),
871 MessageSegment::Thinking { text: content, .. } => {
872 text.push_str(&format!("<think>{}</think>", content))
873 }
874 MessageSegment::RedactedThinking(_) => {}
875 }
876 }
877 text.push('\n');
878 }
879
880 text
881 }
882
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Await the snapshot captured at thread creation before reading state.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                // Persist every message together with its segments, tool
                // uses, tool results, and rendered context text.
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
946
947 pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
948 let mut request = self.to_completion_request(cx);
949 if model.supports_tools() {
950 request.tools = {
951 let mut tools = Vec::new();
952 tools.extend(
953 self.tools()
954 .read(cx)
955 .enabled_tools(cx)
956 .into_iter()
957 .filter_map(|tool| {
958 // Skip tools that cannot be supported
959 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
960 Some(LanguageModelRequestTool {
961 name: tool.name(),
962 description: tool.description(),
963 input_schema,
964 })
965 }),
966 );
967
968 tools
969 };
970 }
971
972 self.stream_completion(request, model, cx);
973 }
974
975 pub fn used_tools_since_last_user_message(&self) -> bool {
976 for message in self.messages.iter().rev() {
977 if self.tool_use.message_has_tool_results(message.id) {
978 return true;
979 } else if message.role == Role::User {
980 return false;
981 }
982 }
983
984 false
985 }
986
    /// Builds the full completion request for this thread: system prompt,
    /// every message (tool results, attached context, segments, tool uses),
    /// and a trailing notice about files that changed since last read.
    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // The system prompt is rendered from the shared project context; if
        // that context isn't ready or rendering fails, surface an error
        // event instead of sending a prompt-less request silently.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Order matters: tool results first, then attached context, then
            // the message's own segments, then tool uses.
            self.tool_use
                .attach_tool_results(message.id, &mut request_message);

            if !message.context.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.context.to_string()));
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1083
1084 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1085 let mut request = LanguageModelRequest {
1086 thread_id: None,
1087 prompt_id: None,
1088 messages: vec![],
1089 tools: Vec::new(),
1090 stop: Vec::new(),
1091 temperature: None,
1092 };
1093
1094 for message in &self.messages {
1095 let mut request_message = LanguageModelRequestMessage {
1096 role: message.role,
1097 content: Vec::new(),
1098 cache: false,
1099 };
1100
1101 // Skip tool results during summarization.
1102 if self.tool_use.message_has_tool_results(message.id) {
1103 continue;
1104 }
1105
1106 for segment in &message.segments {
1107 match segment {
1108 MessageSegment::Text(text) => request_message
1109 .content
1110 .push(MessageContent::Text(text.clone())),
1111 MessageSegment::Thinking { .. } => {}
1112 MessageSegment::RedactedThinking(_) => {}
1113 }
1114 }
1115
1116 if request_message.content.is_empty() {
1117 continue;
1118 }
1119
1120 request.messages.push(request_message);
1121 }
1122
1123 request.messages.push(LanguageModelRequestMessage {
1124 role: Role::User,
1125 content: vec![MessageContent::Text(added_user_message)],
1126 cache: false,
1127 });
1128
1129 request
1130 }
1131
1132 fn attached_tracked_files_state(
1133 &self,
1134 messages: &mut Vec<LanguageModelRequestMessage>,
1135 cx: &App,
1136 ) {
1137 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1138
1139 let mut stale_message = String::new();
1140
1141 let action_log = self.action_log.read(cx);
1142
1143 for stale_file in action_log.stale_buffers(cx) {
1144 let Some(file) = stale_file.read(cx).file() else {
1145 continue;
1146 };
1147
1148 if stale_message.is_empty() {
1149 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1150 }
1151
1152 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1153 }
1154
1155 let mut content = Vec::with_capacity(2);
1156
1157 if !stale_message.is_empty() {
1158 content.push(stale_message.into());
1159 }
1160
1161 if !content.is_empty() {
1162 let context_message = LanguageModelRequestMessage {
1163 role: Role::User,
1164 content,
1165 cache: false,
1166 };
1167
1168 messages.push(context_message);
1169 }
1170 }
1171
    /// Streams a completion from `model` for `request`, applying each streamed
    /// event to the thread (text, thinking, tool uses, token usage) and, once
    /// the stream ends, dispatching pending tools, surfacing errors, and
    /// reporting telemetry.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record the request/response pair when a debug callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Captured up front so per-completion token deltas can be computed at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                while let Some(event) = events.next().await {
                    // Mirror every raw event (or its error string) for the debug callback.
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative within this request, so only
                                // add the delta since the previous update to the thread total.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking {
                                                text: chunk.to_string(),
                                                signature,
                                            }],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to the most recent assistant message,
                                // creating an empty one if none exists yet.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Keep a copy of the partial input only while it is still streaming.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    // Yield between events so the UI stays responsive during long streams.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known error types to specific UI notifications;
                            // anything else becomes a generic error message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        // Remember the overflow so token-usage reporting can
                                        // show it for this model (see `total_token_usage`).
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report only the tokens consumed by this completion.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1440
1441 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1442 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1443 return;
1444 };
1445
1446 if !model.provider.is_authenticated(cx) {
1447 return;
1448 }
1449
1450 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1451 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1452 If the conversation is about a specific subject, include it in the title. \
1453 Be descriptive. DO NOT speak in the first person.";
1454
1455 let request = self.to_summarize_request(added_user_message.into());
1456
1457 self.pending_summary = cx.spawn(async move |this, cx| {
1458 async move {
1459 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1460 let (mut messages, usage) = stream.await?;
1461
1462 if let Some(usage) = usage {
1463 this.update(cx, |_thread, cx| {
1464 cx.emit(ThreadEvent::UsageUpdated(usage));
1465 })
1466 .ok();
1467 }
1468
1469 let mut new_summary = String::new();
1470 while let Some(message) = messages.stream.next().await {
1471 let text = message?;
1472 let mut lines = text.lines();
1473 new_summary.extend(lines.next());
1474
1475 // Stop if the LLM generated multiple lines.
1476 if lines.next().is_some() {
1477 break;
1478 }
1479 }
1480
1481 this.update(cx, |this, cx| {
1482 if !new_summary.is_empty() {
1483 this.summary = Some(new_summary.into());
1484 }
1485
1486 cx.emit(ThreadEvent::SummaryGenerated);
1487 })?;
1488
1489 anyhow::Ok(())
1490 }
1491 .log_err()
1492 .await
1493 });
1494 }
1495
    /// Starts generating a detailed Markdown summary of the thread.
    ///
    /// Returns `None` when the thread is empty, a summary is already generated
    /// (or generating) for the current last message, no summary model is
    /// configured, or its provider is not authenticated. Otherwise returns the
    /// task driving the generation and marks the state as `Generating`.
    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails, roll the state back so a later retry is possible.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            let mut new_detailed_summary = String::new();

            // Chunks that fail to decode are logged and skipped, not fatal.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        // Mark generation as in-flight for the current last message before
        // handing the task back to the caller.
        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }
1562
    /// Reports whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1569
1570 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1571 self.auto_capture_telemetry(cx);
1572 let request = self.to_completion_request(cx);
1573 let messages = Arc::new(request.messages);
1574 let pending_tool_uses = self
1575 .tool_use
1576 .pending_tool_uses()
1577 .into_iter()
1578 .filter(|tool_use| tool_use.status.is_idle())
1579 .cloned()
1580 .collect::<Vec<_>>();
1581
1582 for tool_use in pending_tool_uses.iter() {
1583 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1584 if tool.needs_confirmation(&tool_use.input, cx)
1585 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1586 {
1587 self.tool_use.confirm_tool_use(
1588 tool_use.id.clone(),
1589 tool_use.ui_text.clone(),
1590 tool_use.input.clone(),
1591 messages.clone(),
1592 tool,
1593 );
1594 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1595 } else {
1596 self.run_tool(
1597 tool_use.id.clone(),
1598 tool_use.ui_text.clone(),
1599 tool_use.input.clone(),
1600 &messages,
1601 tool,
1602 cx,
1603 );
1604 }
1605 }
1606 }
1607
1608 pending_tool_uses
1609 }
1610
1611 pub fn run_tool(
1612 &mut self,
1613 tool_use_id: LanguageModelToolUseId,
1614 ui_text: impl Into<SharedString>,
1615 input: serde_json::Value,
1616 messages: &[LanguageModelRequestMessage],
1617 tool: Arc<dyn Tool>,
1618 cx: &mut Context<Thread>,
1619 ) {
1620 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1621 self.tool_use
1622 .run_pending_tool(tool_use_id, ui_text.into(), task);
1623 }
1624
    /// Runs `tool` (unless it is disabled) and returns a task that records the
    /// tool's output on the thread when it completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Disabled tools short-circuit with an error result instead of running.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        // Record the output, then run the "all tools finished?"
                        // bookkeeping, which may send results back to the model.
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            cx,
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
                    })
                    .ok();
            }
        })
    }
1671
1672 fn tool_finished(
1673 &mut self,
1674 tool_use_id: LanguageModelToolUseId,
1675 pending_tool_use: Option<PendingToolUse>,
1676 canceled: bool,
1677 cx: &mut Context<Self>,
1678 ) {
1679 if self.all_tools_finished() {
1680 let model_registry = LanguageModelRegistry::read_global(cx);
1681 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1682 self.attach_tool_results(cx);
1683 if !canceled {
1684 self.send_to_model(model, cx);
1685 }
1686 }
1687 }
1688
1689 cx.emit(ThreadEvent::ToolFinished {
1690 tool_use_id,
1691 pending_tool_use,
1692 });
1693 }
1694
    /// Insert an empty message to be populated with tool results upon send.
    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
        // Tool results are assumed to be waiting on the next message id, so they will populate
        // this empty message before sending to model. Would prefer this to be more straightforward.
        self.insert_message(Role::User, vec![], cx);
        self.auto_capture_telemetry(cx);
    }
1702
1703 /// Cancels the last pending completion, if there are any pending.
1704 ///
1705 /// Returns whether a completion was canceled.
1706 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1707 let canceled = if self.pending_completions.pop().is_some() {
1708 true
1709 } else {
1710 let mut canceled = false;
1711 for pending_tool_use in self.tool_use.cancel_pending() {
1712 canceled = true;
1713 self.tool_finished(
1714 pending_tool_use.id.clone(),
1715 Some(pending_tool_use),
1716 true,
1717 cx,
1718 );
1719 }
1720 canceled
1721 };
1722 self.finalize_pending_checkpoint(cx);
1723 canceled
1724 }
1725
    /// Returns the thread-level feedback rating, if the user has given one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1729
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1733
    /// Records feedback for a single message and reports it, together with a
    /// project snapshot and the serialized thread, to telemetry.
    ///
    /// Returns immediately with `Ok(())` if identical feedback was already
    /// recorded for this message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture the snapshot/serialization tasks before mutating state, so
        // the report reflects the thread as it was when feedback was given.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failures degrade to a null payload rather than
            // dropping the telemetry event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events();

            Ok(())
        })
    }
1791
    /// Records thread-level feedback.
    ///
    /// When the thread contains an assistant message, feedback is attached to
    /// the most recent one via [`Self::report_message_feedback`]; otherwise a
    /// thread-level rating is stored and reported to telemetry directly
    /// (without message-specific fields).
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failures degrade to a null payload rather than
                // dropping the telemetry event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events();

                Ok(())
            })
        }
    }
1837
1838 /// Create a snapshot of the current project state including git information and unsaved buffers.
1839 fn project_snapshot(
1840 project: Entity<Project>,
1841 cx: &mut Context<Self>,
1842 ) -> Task<Arc<ProjectSnapshot>> {
1843 let git_store = project.read(cx).git_store().clone();
1844 let worktree_snapshots: Vec<_> = project
1845 .read(cx)
1846 .visible_worktrees(cx)
1847 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1848 .collect();
1849
1850 cx.spawn(async move |_, cx| {
1851 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1852
1853 let mut unsaved_buffers = Vec::new();
1854 cx.update(|app_cx| {
1855 let buffer_store = project.read(app_cx).buffer_store();
1856 for buffer_handle in buffer_store.read(app_cx).buffers() {
1857 let buffer = buffer_handle.read(app_cx);
1858 if buffer.is_dirty() {
1859 if let Some(file) = buffer.file() {
1860 let path = file.path().to_string_lossy().to_string();
1861 unsaved_buffers.push(path);
1862 }
1863 }
1864 }
1865 })
1866 .ok();
1867
1868 Arc::new(ProjectSnapshot {
1869 worktree_snapshots,
1870 unsaved_buffer_paths: unsaved_buffers,
1871 timestamp: Utc::now(),
1872 })
1873 })
1874 }
1875
    /// Captures a snapshot of a single worktree: its absolute path plus git
    /// metadata (branch, remote URL, HEAD sha, diff) when the worktree belongs
    /// to a local repository. All failures degrade to empty/`None` fields.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository that contains this worktree, if any, then ask
            // it for its git state on its own job queue.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name.to_string());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend to query;
                            // remote ones report just the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha();
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option/Result/Task layers; any failure along
            // the way yields no git state rather than an error.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
1953
1954 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1955 let mut markdown = Vec::new();
1956
1957 if let Some(summary) = self.summary() {
1958 writeln!(markdown, "# {summary}\n")?;
1959 };
1960
1961 for message in self.messages() {
1962 writeln!(
1963 markdown,
1964 "## {role}\n",
1965 role = match message.role {
1966 Role::User => "User",
1967 Role::Assistant => "Assistant",
1968 Role::System => "System",
1969 }
1970 )?;
1971
1972 if !message.context.is_empty() {
1973 writeln!(markdown, "{}", message.context)?;
1974 }
1975
1976 for segment in &message.segments {
1977 match segment {
1978 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1979 MessageSegment::Thinking { text, .. } => {
1980 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
1981 }
1982 MessageSegment::RedactedThinking(_) => {}
1983 }
1984 }
1985
1986 for tool_use in self.tool_uses_for_message(message.id, cx) {
1987 writeln!(
1988 markdown,
1989 "**Use Tool: {} ({})**",
1990 tool_use.name, tool_use.id
1991 )?;
1992 writeln!(markdown, "```json")?;
1993 writeln!(
1994 markdown,
1995 "{}",
1996 serde_json::to_string_pretty(&tool_use.input)?
1997 )?;
1998 writeln!(markdown, "```")?;
1999 }
2000
2001 for tool_result in self.tool_results_for_message(message.id) {
2002 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
2003 if tool_result.is_error {
2004 write!(markdown, " (Error)")?;
2005 }
2006
2007 writeln!(markdown, "**\n")?;
2008 writeln!(markdown, "{}", tool_result.content)?;
2009 }
2010 }
2011
2012 Ok(String::from_utf8_lossy(&markdown).to_string())
2013 }
2014
    /// Marks agent edits within `buffer_range` of `buffer` as accepted in the
    /// action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2025
    /// Marks every outstanding agent edit as accepted in the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2030
    /// Reverts agent edits within the given ranges of `buffer` via the action
    /// log. The returned task resolves once the rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2041
    /// The action log tracking agent edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2045
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2049
2050 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2051 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2052 return;
2053 }
2054
2055 let now = Instant::now();
2056 if let Some(last) = self.last_auto_capture_at {
2057 if now.duration_since(last).as_secs() < 10 {
2058 return;
2059 }
2060 }
2061
2062 self.last_auto_capture_at = Some(now);
2063
2064 let thread_id = self.id().clone();
2065 let github_login = self
2066 .project
2067 .read(cx)
2068 .user_store()
2069 .read(cx)
2070 .current_user()
2071 .map(|user| user.github_login.clone());
2072 let client = self.project.read(cx).client().clone();
2073 let serialize_task = self.serialize(cx);
2074
2075 cx.background_executor()
2076 .spawn(async move {
2077 if let Ok(serialized_thread) = serialize_task.await {
2078 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2079 telemetry::event!(
2080 "Agent Thread Auto-Captured",
2081 thread_id = thread_id.to_string(),
2082 thread_data = thread_data,
2083 auto_capture_reason = "tracked_user",
2084 github_login = github_login
2085 );
2086
2087 client.telemetry().flush_events();
2088 }
2089 }
2090 })
2091 .detach();
2092 }
2093
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2097
2098 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2099 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2100 return TotalTokenUsage::default();
2101 };
2102
2103 let max = model.model.max_token_count();
2104
2105 let index = self
2106 .messages
2107 .iter()
2108 .position(|msg| msg.id == message_id)
2109 .unwrap_or(0);
2110
2111 if index == 0 {
2112 return TotalTokenUsage { total: 0, max };
2113 }
2114
2115 let token_usage = &self
2116 .request_token_usage
2117 .get(index - 1)
2118 .cloned()
2119 .unwrap_or_default();
2120
2121 TotalTokenUsage {
2122 total: token_usage.total_tokens() as usize,
2123 max,
2124 }
2125 }
2126
2127 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2128 let model_registry = LanguageModelRegistry::read_global(cx);
2129 let Some(model) = model_registry.default_model() else {
2130 return TotalTokenUsage::default();
2131 };
2132
2133 let max = model.model.max_token_count();
2134
2135 if let Some(exceeded_error) = &self.exceeded_window_error {
2136 if model.model.id() == exceeded_error.model_id {
2137 return TotalTokenUsage {
2138 total: exceeded_error.token_count,
2139 max,
2140 };
2141 }
2142 }
2143
2144 let total = self
2145 .token_usage_at_last_message()
2146 .unwrap_or_default()
2147 .total_tokens() as usize;
2148
2149 TotalTokenUsage { total, max }
2150 }
2151
2152 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2153 self.request_token_usage
2154 .get(self.messages.len().saturating_sub(1))
2155 .or_else(|| self.request_token_usage.last())
2156 .cloned()
2157 }
2158
    /// Records `token_usage` as the usage for the most recent message.
    ///
    /// The usage list is first resized to the message count — padding new slots
    /// with the last known usage — so the final slot always lines up with the
    /// last message before being overwritten.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2168
2169 pub fn deny_tool_use(
2170 &mut self,
2171 tool_use_id: LanguageModelToolUseId,
2172 tool_name: Arc<str>,
2173 cx: &mut Context<Self>,
2174 ) {
2175 let err = Err(anyhow::anyhow!(
2176 "Permission to run tool action denied by user"
2177 ));
2178
2179 self.tool_use
2180 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2181 self.tool_finished(tool_use_id.clone(), None, true, cx);
2182 }
2183}
2184
/// User-facing errors surfaced while interacting with a thread's language
/// model (emitted via `ThreadEvent::ShowError`).
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before further requests are accepted.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short header and a detailed message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2199
/// Events emitted by a `Thread` as completions stream in, tools run, and
/// summaries are generated.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// The provider reported updated request usage.
    UsageUpdated(RequestUsage),
    /// A streamed completion event was applied to the thread.
    StreamedCompletion,
    /// A chunk of assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant "thinking" text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use whose input is still streaming in.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The completion stream ended, with its stop reason or error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// A short thread title was generated.
    SummaryGenerated,
    SummaryChanged,
    /// The model stopped to use tools; these should now be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully, with an error, or canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// At least one tool use requires user confirmation before it can run.
    ToolConfirmationNeeded,
}
2230
2231impl EventEmitter<ThreadEvent> for Thread {}
2232
/// A completion request currently in flight.
struct PendingCompletion {
    // Monotonic id (from `completion_count`) used to remove this entry when
    // the stream ends.
    id: usize,
    // Keeps the streaming task alive; presumably dropping it cancels the
    // stream (gpui task semantics) — TODO confirm.
    _task: Task<()>,
}
2237
2238#[cfg(test)]
2239mod tests {
2240 use super::*;
2241 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2242 use assistant_settings::AssistantSettings;
2243 use context_server::ContextServerSettings;
2244 use editor::EditorSettings;
2245 use gpui::TestAppContext;
2246 use project::{FakeFs, Project};
2247 use prompt_store::PromptBuilder;
2248 use serde_json::json;
2249 use settings::{Settings, SettingsStore};
2250 use std::sync::Arc;
2251 use theme::ThemeSettings;
2252 use util::path;
2253 use workspace::Workspace;
2254
    /// Verifies that attached file context is rendered into the user message's
    /// `context` field and prepended to the message text in the completion
    /// request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request: the context string should be prepended to
        // the user's text in the rendered request message.
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2320
2321 #[gpui::test]
2322 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2323 init_test_settings(cx);
2324
2325 let project = create_test_project(
2326 cx,
2327 json!({
2328 "file1.rs": "fn function1() {}\n",
2329 "file2.rs": "fn function2() {}\n",
2330 "file3.rs": "fn function3() {}\n",
2331 }),
2332 )
2333 .await;
2334
2335 let (_, _thread_store, thread, context_store) =
2336 setup_test_environment(cx, project.clone()).await;
2337
2338 // Open files individually
2339 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2340 .await
2341 .unwrap();
2342 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2343 .await
2344 .unwrap();
2345 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2346 .await
2347 .unwrap();
2348
2349 // Get the context objects
2350 let contexts = context_store.update(cx, |store, _| store.context().clone());
2351 assert_eq!(contexts.len(), 3);
2352
2353 // First message with context 1
2354 let message1_id = thread.update(cx, |thread, cx| {
2355 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2356 });
2357
2358 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2359 let message2_id = thread.update(cx, |thread, cx| {
2360 thread.insert_user_message(
2361 "Message 2",
2362 vec![contexts[0].clone(), contexts[1].clone()],
2363 None,
2364 cx,
2365 )
2366 });
2367
2368 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2369 let message3_id = thread.update(cx, |thread, cx| {
2370 thread.insert_user_message(
2371 "Message 3",
2372 vec![
2373 contexts[0].clone(),
2374 contexts[1].clone(),
2375 contexts[2].clone(),
2376 ],
2377 None,
2378 cx,
2379 )
2380 });
2381
2382 // Check what contexts are included in each message
2383 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2384 (
2385 thread.message(message1_id).unwrap().clone(),
2386 thread.message(message2_id).unwrap().clone(),
2387 thread.message(message3_id).unwrap().clone(),
2388 )
2389 });
2390
2391 // First message should include context 1
2392 assert!(message1.context.contains("file1.rs"));
2393
2394 // Second message should include only context 2 (not 1)
2395 assert!(!message2.context.contains("file1.rs"));
2396 assert!(message2.context.contains("file2.rs"));
2397
2398 // Third message should include only context 3 (not 1 or 2)
2399 assert!(!message3.context.contains("file1.rs"));
2400 assert!(!message3.context.contains("file2.rs"));
2401 assert!(message3.context.contains("file3.rs"));
2402
2403 // Check entire request to make sure all contexts are properly included
2404 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2405
2406 // The request should contain all 3 messages
2407 assert_eq!(request.messages.len(), 4);
2408
2409 // Check that the contexts are properly formatted in each message
2410 assert!(request.messages[1].string_contents().contains("file1.rs"));
2411 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2412 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2413
2414 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2415 assert!(request.messages[2].string_contents().contains("file2.rs"));
2416 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2417
2418 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2419 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2420 assert!(request.messages[3].string_contents().contains("file3.rs"));
2421 }
2422
2423 #[gpui::test]
2424 async fn test_message_without_files(cx: &mut TestAppContext) {
2425 init_test_settings(cx);
2426
2427 let project = create_test_project(
2428 cx,
2429 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2430 )
2431 .await;
2432
2433 let (_, _thread_store, thread, _context_store) =
2434 setup_test_environment(cx, project.clone()).await;
2435
2436 // Insert user message without any context (empty context vector)
2437 let message_id = thread.update(cx, |thread, cx| {
2438 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2439 });
2440
2441 // Check content and context in message object
2442 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2443
2444 // Context should be empty when no files are included
2445 assert_eq!(message.role, Role::User);
2446 assert_eq!(message.segments.len(), 1);
2447 assert_eq!(
2448 message.segments[0],
2449 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2450 );
2451 assert_eq!(message.context, "");
2452
2453 // Check message in request
2454 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2455
2456 assert_eq!(request.messages.len(), 2);
2457 assert_eq!(
2458 request.messages[1].string_contents(),
2459 "What is the best way to learn Rust?"
2460 );
2461
2462 // Add second message, also without context
2463 let message2_id = thread.update(cx, |thread, cx| {
2464 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2465 });
2466
2467 let message2 =
2468 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2469 assert_eq!(message2.context, "");
2470
2471 // Check that both messages appear in the request
2472 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2473
2474 assert_eq!(request.messages.len(), 3);
2475 assert_eq!(
2476 request.messages[1].string_contents(),
2477 "What is the best way to learn Rust?"
2478 );
2479 assert_eq!(
2480 request.messages[2].string_contents(),
2481 "Are there any good books?"
2482 );
2483 }
2484
2485 #[gpui::test]
2486 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2487 init_test_settings(cx);
2488
2489 let project = create_test_project(
2490 cx,
2491 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2492 )
2493 .await;
2494
2495 let (_workspace, _thread_store, thread, context_store) =
2496 setup_test_environment(cx, project.clone()).await;
2497
2498 // Open buffer and add it to context
2499 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2500 .await
2501 .unwrap();
2502
2503 let context =
2504 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2505
2506 // Insert user message with the buffer as context
2507 thread.update(cx, |thread, cx| {
2508 thread.insert_user_message("Explain this code", vec![context], None, cx)
2509 });
2510
2511 // Create a request and check that it doesn't have a stale buffer warning yet
2512 let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2513
2514 // Make sure we don't have a stale file warning yet
2515 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2516 msg.string_contents()
2517 .contains("These files changed since last read:")
2518 });
2519 assert!(
2520 !has_stale_warning,
2521 "Should not have stale buffer warning before buffer is modified"
2522 );
2523
2524 // Modify the buffer
2525 buffer.update(cx, |buffer, cx| {
2526 // Find a position at the end of line 1
2527 buffer.edit(
2528 [(1..1, "\n println!(\"Added a new line\");\n")],
2529 None,
2530 cx,
2531 );
2532 });
2533
2534 // Insert another user message without context
2535 thread.update(cx, |thread, cx| {
2536 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2537 });
2538
2539 // Create a new request and check for the stale buffer warning
2540 let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2541
2542 // We should have a stale file warning as the last message
2543 let last_message = new_request
2544 .messages
2545 .last()
2546 .expect("Request should have messages");
2547
2548 // The last message should be the stale buffer notification
2549 assert_eq!(last_message.role, Role::User);
2550
2551 // Check the exact content of the message
2552 let expected_content = "These files changed since last read:\n- code.rs\n";
2553 assert_eq!(
2554 last_message.string_contents(),
2555 expected_content,
2556 "Last message should be exactly the stale buffer notification"
2557 );
2558 }
2559
2560 fn init_test_settings(cx: &mut TestAppContext) {
2561 cx.update(|cx| {
2562 let settings_store = SettingsStore::test(cx);
2563 cx.set_global(settings_store);
2564 language::init(cx);
2565 Project::init_settings(cx);
2566 AssistantSettings::register(cx);
2567 prompt_store::init(cx);
2568 thread_store::init(cx);
2569 workspace::init_settings(cx);
2570 ThemeSettings::register(cx);
2571 ContextServerSettings::register(cx);
2572 EditorSettings::register(cx);
2573 });
2574 }
2575
2576 // Helper to create a test project with test files
2577 async fn create_test_project(
2578 cx: &mut TestAppContext,
2579 files: serde_json::Value,
2580 ) -> Entity<Project> {
2581 let fs = FakeFs::new(cx.executor());
2582 fs.insert_tree(path!("/test"), files).await;
2583 Project::test(fs, [path!("/test").as_ref()], cx).await
2584 }
2585
2586 async fn setup_test_environment(
2587 cx: &mut TestAppContext,
2588 project: Entity<Project>,
2589 ) -> (
2590 Entity<Workspace>,
2591 Entity<ThreadStore>,
2592 Entity<Thread>,
2593 Entity<ContextStore>,
2594 ) {
2595 let (workspace, cx) =
2596 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2597
2598 let thread_store = cx
2599 .update(|_, cx| {
2600 ThreadStore::load(
2601 project.clone(),
2602 cx.new(|_| ToolWorkingSet::default()),
2603 Arc::new(PromptBuilder::new(None).unwrap()),
2604 cx,
2605 )
2606 })
2607 .await
2608 .unwrap();
2609
2610 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2611 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2612
2613 (workspace, thread_store, thread, context_store)
2614 }
2615
2616 async fn add_file_to_context(
2617 project: &Entity<Project>,
2618 context_store: &Entity<ContextStore>,
2619 path: &str,
2620 cx: &mut TestAppContext,
2621 ) -> Result<Entity<language::Buffer>> {
2622 let buffer_path = project
2623 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2624 .unwrap();
2625
2626 let buffer = project
2627 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2628 .await
2629 .unwrap();
2630
2631 context_store
2632 .update(cx, |store, cx| {
2633 store.add_file_from_buffer(buffer.clone(), cx)
2634 })
2635 .await?;
2636
2637 Ok(buffer)
2638 }
2639}