1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState, USING_TOOL_MARKER};
42
43#[derive(
44 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
45)]
46pub struct ThreadId(Arc<str>);
47
48impl ThreadId {
49 pub fn new() -> Self {
50 Self(Uuid::new_v4().to_string().into())
51 }
52}
53
54impl std::fmt::Display for ThreadId {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}", self.0)
57 }
58}
59
60impl From<&str> for ThreadId {
61 fn from(value: &str) -> Self {
62 Self(value.into())
63 }
64}
65
66/// The ID of the user prompt that initiated a request.
67///
68/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
70pub struct PromptId(Arc<str>);
71
72impl PromptId {
73 pub fn new() -> Self {
74 Self(Uuid::new_v4().to_string().into())
75 }
76}
77
78impl std::fmt::Display for PromptId {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 write!(f, "{}", self.0)
81 }
82}
83
84#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
85pub struct MessageId(pub(crate) usize);
86
87impl MessageId {
88 fn post_inc(&mut self) -> Self {
89 Self(post_inc(&mut self.0))
90 }
91}
92
93/// A message in a [`Thread`].
94#[derive(Debug, Clone)]
95pub struct Message {
96 pub id: MessageId,
97 pub role: Role,
98 pub segments: Vec<MessageSegment>,
99 pub context: String,
100}
101
102impl Message {
103 /// Returns whether the message contains any meaningful text that should be displayed
104 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
105 pub fn should_display_content(&self) -> bool {
106 self.segments.iter().all(|segment| segment.should_display())
107 }
108
109 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
110 if let Some(MessageSegment::Thinking {
111 text: segment,
112 signature: current_signature,
113 }) = self.segments.last_mut()
114 {
115 if let Some(signature) = signature {
116 *current_signature = Some(signature);
117 }
118 segment.push_str(text);
119 } else {
120 self.segments.push(MessageSegment::Thinking {
121 text: text.to_string(),
122 signature,
123 });
124 }
125 }
126
127 pub fn push_text(&mut self, text: &str) {
128 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
129 segment.push_str(text);
130 } else {
131 self.segments.push(MessageSegment::Text(text.to_string()));
132 }
133 }
134
135 pub fn to_string(&self) -> String {
136 let mut result = String::new();
137
138 if !self.context.is_empty() {
139 result.push_str(&self.context);
140 }
141
142 for segment in &self.segments {
143 match segment {
144 MessageSegment::Text(text) => result.push_str(text),
145 MessageSegment::Thinking { text, .. } => {
146 result.push_str("<think>\n");
147 result.push_str(text);
148 result.push_str("\n</think>");
149 }
150 MessageSegment::RedactedThinking(_) => {}
151 }
152 }
153
154 result
155 }
156}
157
158#[derive(Debug, Clone, PartialEq, Eq)]
159pub enum MessageSegment {
160 Text(String),
161 Thinking {
162 text: String,
163 signature: Option<String>,
164 },
165 RedactedThinking(Vec<u8>),
166}
167
168impl MessageSegment {
169 pub fn should_display(&self) -> bool {
170 // We add USING_TOOL_MARKER when making a request that includes tool uses
171 // without non-whitespace text around them, and this can cause the model
172 // to mimic the pattern, so we consider those segments not displayable.
173 match self {
174 Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
175 Self::Thinking { text, .. } => text.is_empty() || text.trim() == USING_TOOL_MARKER,
176 Self::RedactedThinking(_) => false,
177 }
178 }
179}
180
181#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct ProjectSnapshot {
183 pub worktree_snapshots: Vec<WorktreeSnapshot>,
184 pub unsaved_buffer_paths: Vec<String>,
185 pub timestamp: DateTime<Utc>,
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct WorktreeSnapshot {
190 pub worktree_path: String,
191 pub git_state: Option<GitState>,
192}
193
194#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct GitState {
196 pub remote_url: Option<String>,
197 pub head_sha: Option<String>,
198 pub current_branch: Option<String>,
199 pub diff: Option<String>,
200}
201
202#[derive(Clone)]
203pub struct ThreadCheckpoint {
204 message_id: MessageId,
205 git_checkpoint: GitStoreCheckpoint,
206}
207
208#[derive(Copy, Clone, Debug, PartialEq, Eq)]
209pub enum ThreadFeedback {
210 Positive,
211 Negative,
212}
213
214pub enum LastRestoreCheckpoint {
215 Pending {
216 message_id: MessageId,
217 },
218 Error {
219 message_id: MessageId,
220 error: String,
221 },
222}
223
224impl LastRestoreCheckpoint {
225 pub fn message_id(&self) -> MessageId {
226 match self {
227 LastRestoreCheckpoint::Pending { message_id } => *message_id,
228 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
229 }
230 }
231}
232
233#[derive(Clone, Debug, Default, Serialize, Deserialize)]
234pub enum DetailedSummaryState {
235 #[default]
236 NotGenerated,
237 Generating {
238 message_id: MessageId,
239 },
240 Generated {
241 text: SharedString,
242 message_id: MessageId,
243 },
244}
245
246#[derive(Default)]
247pub struct TotalTokenUsage {
248 pub total: usize,
249 pub max: usize,
250}
251
252impl TotalTokenUsage {
253 pub fn ratio(&self) -> TokenUsageRatio {
254 #[cfg(debug_assertions)]
255 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
256 .unwrap_or("0.8".to_string())
257 .parse()
258 .unwrap();
259 #[cfg(not(debug_assertions))]
260 let warning_threshold: f32 = 0.8;
261
262 if self.total >= self.max {
263 TokenUsageRatio::Exceeded
264 } else if self.total as f32 / self.max as f32 >= warning_threshold {
265 TokenUsageRatio::Warning
266 } else {
267 TokenUsageRatio::Normal
268 }
269 }
270
271 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
272 TotalTokenUsage {
273 total: self.total + tokens,
274 max: self.max,
275 }
276 }
277}
278
279#[derive(Debug, Default, PartialEq, Eq)]
280pub enum TokenUsageRatio {
281 #[default]
282 Normal,
283 Warning,
284 Exceeded,
285}
286
287/// A thread of conversation with the LLM.
288pub struct Thread {
289 id: ThreadId,
290 updated_at: DateTime<Utc>,
291 summary: Option<SharedString>,
292 pending_summary: Task<Option<()>>,
293 detailed_summary_state: DetailedSummaryState,
294 messages: Vec<Message>,
295 next_message_id: MessageId,
296 last_prompt_id: PromptId,
297 context: BTreeMap<ContextId, AssistantContext>,
298 context_by_message: HashMap<MessageId, Vec<ContextId>>,
299 project_context: SharedProjectContext,
300 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
301 completion_count: usize,
302 pending_completions: Vec<PendingCompletion>,
303 project: Entity<Project>,
304 prompt_builder: Arc<PromptBuilder>,
305 tools: Entity<ToolWorkingSet>,
306 tool_use: ToolUseState,
307 action_log: Entity<ActionLog>,
308 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
309 pending_checkpoint: Option<ThreadCheckpoint>,
310 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
311 request_token_usage: Vec<TokenUsage>,
312 cumulative_token_usage: TokenUsage,
313 exceeded_window_error: Option<ExceededWindowError>,
314 feedback: Option<ThreadFeedback>,
315 message_feedback: HashMap<MessageId, ThreadFeedback>,
316 last_auto_capture_at: Option<Instant>,
317 request_callback: Option<
318 Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
319 >,
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize)]
323pub struct ExceededWindowError {
324 /// Model used when last message exceeded context window
325 model_id: LanguageModelId,
326 /// Token count including last message
327 token_count: usize,
328}
329
330impl Thread {
331 pub fn new(
332 project: Entity<Project>,
333 tools: Entity<ToolWorkingSet>,
334 prompt_builder: Arc<PromptBuilder>,
335 system_prompt: SharedProjectContext,
336 cx: &mut Context<Self>,
337 ) -> Self {
338 Self {
339 id: ThreadId::new(),
340 updated_at: Utc::now(),
341 summary: None,
342 pending_summary: Task::ready(None),
343 detailed_summary_state: DetailedSummaryState::NotGenerated,
344 messages: Vec::new(),
345 next_message_id: MessageId(0),
346 last_prompt_id: PromptId::new(),
347 context: BTreeMap::default(),
348 context_by_message: HashMap::default(),
349 project_context: system_prompt,
350 checkpoints_by_message: HashMap::default(),
351 completion_count: 0,
352 pending_completions: Vec::new(),
353 project: project.clone(),
354 prompt_builder,
355 tools: tools.clone(),
356 last_restore_checkpoint: None,
357 pending_checkpoint: None,
358 tool_use: ToolUseState::new(tools.clone()),
359 action_log: cx.new(|_| ActionLog::new(project.clone())),
360 initial_project_snapshot: {
361 let project_snapshot = Self::project_snapshot(project, cx);
362 cx.foreground_executor()
363 .spawn(async move { Some(project_snapshot.await) })
364 .shared()
365 },
366 request_token_usage: Vec::new(),
367 cumulative_token_usage: TokenUsage::default(),
368 exceeded_window_error: None,
369 feedback: None,
370 message_feedback: HashMap::default(),
371 last_auto_capture_at: None,
372 request_callback: None,
373 }
374 }
375
376 pub fn deserialize(
377 id: ThreadId,
378 serialized: SerializedThread,
379 project: Entity<Project>,
380 tools: Entity<ToolWorkingSet>,
381 prompt_builder: Arc<PromptBuilder>,
382 project_context: SharedProjectContext,
383 cx: &mut Context<Self>,
384 ) -> Self {
385 let next_message_id = MessageId(
386 serialized
387 .messages
388 .last()
389 .map(|message| message.id.0 + 1)
390 .unwrap_or(0),
391 );
392 let tool_use =
393 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
394
395 Self {
396 id,
397 updated_at: serialized.updated_at,
398 summary: Some(serialized.summary),
399 pending_summary: Task::ready(None),
400 detailed_summary_state: serialized.detailed_summary_state,
401 messages: serialized
402 .messages
403 .into_iter()
404 .map(|message| Message {
405 id: message.id,
406 role: message.role,
407 segments: message
408 .segments
409 .into_iter()
410 .map(|segment| match segment {
411 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
412 SerializedMessageSegment::Thinking { text, signature } => {
413 MessageSegment::Thinking { text, signature }
414 }
415 SerializedMessageSegment::RedactedThinking { data } => {
416 MessageSegment::RedactedThinking(data)
417 }
418 })
419 .collect(),
420 context: message.context,
421 })
422 .collect(),
423 next_message_id,
424 last_prompt_id: PromptId::new(),
425 context: BTreeMap::default(),
426 context_by_message: HashMap::default(),
427 project_context,
428 checkpoints_by_message: HashMap::default(),
429 completion_count: 0,
430 pending_completions: Vec::new(),
431 last_restore_checkpoint: None,
432 pending_checkpoint: None,
433 project: project.clone(),
434 prompt_builder,
435 tools,
436 tool_use,
437 action_log: cx.new(|_| ActionLog::new(project)),
438 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
439 request_token_usage: serialized.request_token_usage,
440 cumulative_token_usage: serialized.cumulative_token_usage,
441 exceeded_window_error: None,
442 feedback: None,
443 message_feedback: HashMap::default(),
444 last_auto_capture_at: None,
445 request_callback: None,
446 }
447 }
448
449 pub fn set_request_callback(
450 &mut self,
451 callback: impl 'static
452 + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
453 ) {
454 self.request_callback = Some(Box::new(callback));
455 }
456
457 pub fn id(&self) -> &ThreadId {
458 &self.id
459 }
460
461 pub fn is_empty(&self) -> bool {
462 self.messages.is_empty()
463 }
464
465 pub fn updated_at(&self) -> DateTime<Utc> {
466 self.updated_at
467 }
468
469 pub fn touch_updated_at(&mut self) {
470 self.updated_at = Utc::now();
471 }
472
473 pub fn advance_prompt_id(&mut self) {
474 self.last_prompt_id = PromptId::new();
475 }
476
477 pub fn summary(&self) -> Option<SharedString> {
478 self.summary.clone()
479 }
480
481 pub fn project_context(&self) -> SharedProjectContext {
482 self.project_context.clone()
483 }
484
485 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
486
487 pub fn summary_or_default(&self) -> SharedString {
488 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
489 }
490
491 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
492 let Some(current_summary) = &self.summary else {
493 // Don't allow setting summary until generated
494 return;
495 };
496
497 let mut new_summary = new_summary.into();
498
499 if new_summary.is_empty() {
500 new_summary = Self::DEFAULT_SUMMARY;
501 }
502
503 if current_summary != &new_summary {
504 self.summary = Some(new_summary);
505 cx.emit(ThreadEvent::SummaryChanged);
506 }
507 }
508
509 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
510 self.latest_detailed_summary()
511 .unwrap_or_else(|| self.text().into())
512 }
513
514 fn latest_detailed_summary(&self) -> Option<SharedString> {
515 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
516 Some(text.clone())
517 } else {
518 None
519 }
520 }
521
522 pub fn message(&self, id: MessageId) -> Option<&Message> {
523 self.messages.iter().find(|message| message.id == id)
524 }
525
526 pub fn messages(&self) -> impl Iterator<Item = &Message> {
527 self.messages.iter()
528 }
529
530 pub fn is_generating(&self) -> bool {
531 !self.pending_completions.is_empty() || !self.all_tools_finished()
532 }
533
534 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
535 &self.tools
536 }
537
538 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
539 self.tool_use
540 .pending_tool_uses()
541 .into_iter()
542 .find(|tool_use| &tool_use.id == id)
543 }
544
545 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
546 self.tool_use
547 .pending_tool_uses()
548 .into_iter()
549 .filter(|tool_use| tool_use.status.needs_confirmation())
550 }
551
552 pub fn has_pending_tool_uses(&self) -> bool {
553 !self.tool_use.pending_tool_uses().is_empty()
554 }
555
556 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
557 self.checkpoints_by_message.get(&id).cloned()
558 }
559
560 pub fn restore_checkpoint(
561 &mut self,
562 checkpoint: ThreadCheckpoint,
563 cx: &mut Context<Self>,
564 ) -> Task<Result<()>> {
565 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
566 message_id: checkpoint.message_id,
567 });
568 cx.emit(ThreadEvent::CheckpointChanged);
569 cx.notify();
570
571 let git_store = self.project().read(cx).git_store().clone();
572 let restore = git_store.update(cx, |git_store, cx| {
573 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
574 });
575
576 cx.spawn(async move |this, cx| {
577 let result = restore.await;
578 this.update(cx, |this, cx| {
579 if let Err(err) = result.as_ref() {
580 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
581 message_id: checkpoint.message_id,
582 error: err.to_string(),
583 });
584 } else {
585 this.truncate(checkpoint.message_id, cx);
586 this.last_restore_checkpoint = None;
587 }
588 this.pending_checkpoint = None;
589 cx.emit(ThreadEvent::CheckpointChanged);
590 cx.notify();
591 })?;
592 result
593 })
594 }
595
596 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
597 let pending_checkpoint = if self.is_generating() {
598 return;
599 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
600 checkpoint
601 } else {
602 return;
603 };
604
605 let git_store = self.project.read(cx).git_store().clone();
606 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
607 cx.spawn(async move |this, cx| match final_checkpoint.await {
608 Ok(final_checkpoint) => {
609 let equal = git_store
610 .update(cx, |store, cx| {
611 store.compare_checkpoints(
612 pending_checkpoint.git_checkpoint.clone(),
613 final_checkpoint.clone(),
614 cx,
615 )
616 })?
617 .await
618 .unwrap_or(false);
619
620 if equal {
621 git_store
622 .update(cx, |store, cx| {
623 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
624 })?
625 .detach();
626 } else {
627 this.update(cx, |this, cx| {
628 this.insert_checkpoint(pending_checkpoint, cx)
629 })?;
630 }
631
632 git_store
633 .update(cx, |store, cx| {
634 store.delete_checkpoint(final_checkpoint, cx)
635 })?
636 .detach();
637
638 Ok(())
639 }
640 Err(_) => this.update(cx, |this, cx| {
641 this.insert_checkpoint(pending_checkpoint, cx)
642 }),
643 })
644 .detach();
645 }
646
647 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
648 self.checkpoints_by_message
649 .insert(checkpoint.message_id, checkpoint);
650 cx.emit(ThreadEvent::CheckpointChanged);
651 cx.notify();
652 }
653
654 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
655 self.last_restore_checkpoint.as_ref()
656 }
657
658 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
659 let Some(message_ix) = self
660 .messages
661 .iter()
662 .rposition(|message| message.id == message_id)
663 else {
664 return;
665 };
666 for deleted_message in self.messages.drain(message_ix..) {
667 self.context_by_message.remove(&deleted_message.id);
668 self.checkpoints_by_message.remove(&deleted_message.id);
669 }
670 cx.notify();
671 }
672
673 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
674 self.context_by_message
675 .get(&id)
676 .into_iter()
677 .flat_map(|context| {
678 context
679 .iter()
680 .filter_map(|context_id| self.context.get(&context_id))
681 })
682 }
683
684 /// Returns whether all of the tool uses have finished running.
685 pub fn all_tools_finished(&self) -> bool {
686 // If the only pending tool uses left are the ones with errors, then
687 // that means that we've finished running all of the pending tools.
688 self.tool_use
689 .pending_tool_uses()
690 .iter()
691 .all(|tool_use| tool_use.status.is_error())
692 }
693
694 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
695 self.tool_use.tool_uses_for_message(id, cx)
696 }
697
698 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
699 self.tool_use.tool_results_for_message(id)
700 }
701
702 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
703 self.tool_use.tool_result(id)
704 }
705
706 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
707 Some(&self.tool_use.tool_result(id)?.content)
708 }
709
710 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
711 self.tool_use.tool_result_card(id).cloned()
712 }
713
714 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
715 self.tool_use.message_has_tool_results(message_id)
716 }
717
718 /// Filter out contexts that have already been included in previous messages
719 pub fn filter_new_context<'a>(
720 &self,
721 context: impl Iterator<Item = &'a AssistantContext>,
722 ) -> impl Iterator<Item = &'a AssistantContext> {
723 context.filter(|ctx| self.is_context_new(ctx))
724 }
725
726 fn is_context_new(&self, context: &AssistantContext) -> bool {
727 !self.context.contains_key(&context.id())
728 }
729
730 pub fn insert_user_message(
731 &mut self,
732 text: impl Into<String>,
733 context: Vec<AssistantContext>,
734 git_checkpoint: Option<GitStoreCheckpoint>,
735 cx: &mut Context<Self>,
736 ) -> MessageId {
737 let text = text.into();
738
739 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
740
741 let new_context: Vec<_> = context
742 .into_iter()
743 .filter(|ctx| self.is_context_new(ctx))
744 .collect();
745
746 if !new_context.is_empty() {
747 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
748 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
749 message.context = context_string;
750 }
751 }
752
753 self.action_log.update(cx, |log, cx| {
754 // Track all buffers added as context
755 for ctx in &new_context {
756 match ctx {
757 AssistantContext::File(file_ctx) => {
758 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
759 }
760 AssistantContext::Directory(dir_ctx) => {
761 for context_buffer in &dir_ctx.context_buffers {
762 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
763 }
764 }
765 AssistantContext::Symbol(symbol_ctx) => {
766 log.buffer_added_as_context(
767 symbol_ctx.context_symbol.buffer.clone(),
768 cx,
769 );
770 }
771 AssistantContext::Excerpt(excerpt_context) => {
772 log.buffer_added_as_context(
773 excerpt_context.context_buffer.buffer.clone(),
774 cx,
775 );
776 }
777 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
778 }
779 }
780 });
781 }
782
783 let context_ids = new_context
784 .iter()
785 .map(|context| context.id())
786 .collect::<Vec<_>>();
787 self.context.extend(
788 new_context
789 .into_iter()
790 .map(|context| (context.id(), context)),
791 );
792 self.context_by_message.insert(message_id, context_ids);
793
794 if let Some(git_checkpoint) = git_checkpoint {
795 self.pending_checkpoint = Some(ThreadCheckpoint {
796 message_id,
797 git_checkpoint,
798 });
799 }
800
801 self.auto_capture_telemetry(cx);
802
803 message_id
804 }
805
806 pub fn insert_message(
807 &mut self,
808 role: Role,
809 segments: Vec<MessageSegment>,
810 cx: &mut Context<Self>,
811 ) -> MessageId {
812 let id = self.next_message_id.post_inc();
813 self.messages.push(Message {
814 id,
815 role,
816 segments,
817 context: String::new(),
818 });
819 self.touch_updated_at();
820 cx.emit(ThreadEvent::MessageAdded(id));
821 id
822 }
823
824 pub fn edit_message(
825 &mut self,
826 id: MessageId,
827 new_role: Role,
828 new_segments: Vec<MessageSegment>,
829 cx: &mut Context<Self>,
830 ) -> bool {
831 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
832 return false;
833 };
834 message.role = new_role;
835 message.segments = new_segments;
836 self.touch_updated_at();
837 cx.emit(ThreadEvent::MessageEdited(id));
838 true
839 }
840
841 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
842 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
843 return false;
844 };
845 self.messages.remove(index);
846 self.context_by_message.remove(&id);
847 self.touch_updated_at();
848 cx.emit(ThreadEvent::MessageDeleted(id));
849 true
850 }
851
852 /// Returns the representation of this [`Thread`] in a textual form.
853 ///
854 /// This is the representation we use when attaching a thread as context to another thread.
855 pub fn text(&self) -> String {
856 let mut text = String::new();
857
858 for message in &self.messages {
859 text.push_str(match message.role {
860 language_model::Role::User => "User:",
861 language_model::Role::Assistant => "Assistant:",
862 language_model::Role::System => "System:",
863 });
864 text.push('\n');
865
866 for segment in &message.segments {
867 match segment {
868 MessageSegment::Text(content) => text.push_str(content),
869 MessageSegment::Thinking { text: content, .. } => {
870 text.push_str(&format!("<think>{}</think>", content))
871 }
872 MessageSegment::RedactedThinking(_) => {}
873 }
874 }
875 text.push('\n');
876 }
877
878 text
879 }
880
881 /// Serializes this thread into a format for storage or telemetry.
882 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
883 let initial_project_snapshot = self.initial_project_snapshot.clone();
884 cx.spawn(async move |this, cx| {
885 let initial_project_snapshot = initial_project_snapshot.await;
886 this.read_with(cx, |this, cx| SerializedThread {
887 version: SerializedThread::VERSION.to_string(),
888 summary: this.summary_or_default(),
889 updated_at: this.updated_at(),
890 messages: this
891 .messages()
892 .map(|message| SerializedMessage {
893 id: message.id,
894 role: message.role,
895 segments: message
896 .segments
897 .iter()
898 .map(|segment| match segment {
899 MessageSegment::Text(text) => {
900 SerializedMessageSegment::Text { text: text.clone() }
901 }
902 MessageSegment::Thinking { text, signature } => {
903 SerializedMessageSegment::Thinking {
904 text: text.clone(),
905 signature: signature.clone(),
906 }
907 }
908 MessageSegment::RedactedThinking(data) => {
909 SerializedMessageSegment::RedactedThinking {
910 data: data.clone(),
911 }
912 }
913 })
914 .collect(),
915 tool_uses: this
916 .tool_uses_for_message(message.id, cx)
917 .into_iter()
918 .map(|tool_use| SerializedToolUse {
919 id: tool_use.id,
920 name: tool_use.name,
921 input: tool_use.input,
922 })
923 .collect(),
924 tool_results: this
925 .tool_results_for_message(message.id)
926 .into_iter()
927 .map(|tool_result| SerializedToolResult {
928 tool_use_id: tool_result.tool_use_id.clone(),
929 is_error: tool_result.is_error,
930 content: tool_result.content.clone(),
931 })
932 .collect(),
933 context: message.context.clone(),
934 })
935 .collect(),
936 initial_project_snapshot,
937 cumulative_token_usage: this.cumulative_token_usage,
938 request_token_usage: this.request_token_usage.clone(),
939 detailed_summary_state: this.detailed_summary_state.clone(),
940 exceeded_window_error: this.exceeded_window_error.clone(),
941 })
942 })
943 }
944
945 pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
946 let mut request = self.to_completion_request(cx);
947 if model.supports_tools() {
948 request.tools = {
949 let mut tools = Vec::new();
950 tools.extend(
951 self.tools()
952 .read(cx)
953 .enabled_tools(cx)
954 .into_iter()
955 .filter_map(|tool| {
956 // Skip tools that cannot be supported
957 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
958 Some(LanguageModelRequestTool {
959 name: tool.name(),
960 description: tool.description(),
961 input_schema,
962 })
963 }),
964 );
965
966 tools
967 };
968 }
969
970 self.stream_completion(request, model, cx);
971 }
972
973 pub fn used_tools_since_last_user_message(&self) -> bool {
974 for message in self.messages.iter().rev() {
975 if self.tool_use.message_has_tool_results(message.id) {
976 return true;
977 } else if message.role == Role::User {
978 return false;
979 }
980 }
981
982 false
983 }
984
985 pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
986 let mut request = LanguageModelRequest {
987 thread_id: Some(self.id.to_string()),
988 prompt_id: Some(self.last_prompt_id.to_string()),
989 messages: vec![],
990 tools: Vec::new(),
991 stop: Vec::new(),
992 temperature: None,
993 };
994
995 if let Some(project_context) = self.project_context.borrow().as_ref() {
996 match self
997 .prompt_builder
998 .generate_assistant_system_prompt(project_context)
999 {
1000 Err(err) => {
1001 let message = format!("{err:?}").into();
1002 log::error!("{message}");
1003 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1004 header: "Error generating system prompt".into(),
1005 message,
1006 }));
1007 }
1008 Ok(system_prompt) => {
1009 request.messages.push(LanguageModelRequestMessage {
1010 role: Role::System,
1011 content: vec![MessageContent::Text(system_prompt)],
1012 cache: true,
1013 });
1014 }
1015 }
1016 } else {
1017 let message = "Context for system prompt unexpectedly not ready.".into();
1018 log::error!("{message}");
1019 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1020 header: "Error generating system prompt".into(),
1021 message,
1022 }));
1023 }
1024
1025 for message in &self.messages {
1026 let mut request_message = LanguageModelRequestMessage {
1027 role: message.role,
1028 content: Vec::new(),
1029 cache: false,
1030 };
1031
1032 self.tool_use
1033 .attach_tool_results(message.id, &mut request_message);
1034
1035 if !message.context.is_empty() {
1036 request_message
1037 .content
1038 .push(MessageContent::Text(message.context.to_string()));
1039 }
1040
1041 for segment in &message.segments {
1042 match segment {
1043 MessageSegment::Text(text) => {
1044 if !text.is_empty() {
1045 request_message
1046 .content
1047 .push(MessageContent::Text(text.into()));
1048 }
1049 }
1050 MessageSegment::Thinking { text, signature } => {
1051 if !text.is_empty() {
1052 request_message.content.push(MessageContent::Thinking {
1053 text: text.into(),
1054 signature: signature.clone(),
1055 });
1056 }
1057 }
1058 MessageSegment::RedactedThinking(data) => {
1059 request_message
1060 .content
1061 .push(MessageContent::RedactedThinking(data.clone()));
1062 }
1063 };
1064 }
1065
1066 self.tool_use
1067 .attach_tool_uses(message.id, &mut request_message);
1068
1069 request.messages.push(request_message);
1070 }
1071
1072 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1073 if let Some(last) = request.messages.last_mut() {
1074 last.cache = true;
1075 }
1076
1077 self.attached_tracked_files_state(&mut request.messages, cx);
1078
1079 request
1080 }
1081
1082 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1083 let mut request = LanguageModelRequest {
1084 thread_id: None,
1085 prompt_id: None,
1086 messages: vec![],
1087 tools: Vec::new(),
1088 stop: Vec::new(),
1089 temperature: None,
1090 };
1091
1092 for message in &self.messages {
1093 let mut request_message = LanguageModelRequestMessage {
1094 role: message.role,
1095 content: Vec::new(),
1096 cache: false,
1097 };
1098
1099 // Skip tool results during summarization.
1100 if self.tool_use.message_has_tool_results(message.id) {
1101 continue;
1102 }
1103
1104 for segment in &message.segments {
1105 match segment {
1106 MessageSegment::Text(text) => request_message
1107 .content
1108 .push(MessageContent::Text(text.clone())),
1109 MessageSegment::Thinking { .. } => {}
1110 MessageSegment::RedactedThinking(_) => {}
1111 }
1112 }
1113
1114 if request_message.content.is_empty() {
1115 continue;
1116 }
1117
1118 request.messages.push(request_message);
1119 }
1120
1121 request.messages.push(LanguageModelRequestMessage {
1122 role: Role::User,
1123 content: vec![MessageContent::Text(added_user_message)],
1124 cache: false,
1125 });
1126
1127 request
1128 }
1129
1130 fn attached_tracked_files_state(
1131 &self,
1132 messages: &mut Vec<LanguageModelRequestMessage>,
1133 cx: &App,
1134 ) {
1135 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1136
1137 let mut stale_message = String::new();
1138
1139 let action_log = self.action_log.read(cx);
1140
1141 for stale_file in action_log.stale_buffers(cx) {
1142 let Some(file) = stale_file.read(cx).file() else {
1143 continue;
1144 };
1145
1146 if stale_message.is_empty() {
1147 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1148 }
1149
1150 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1151 }
1152
1153 let mut content = Vec::with_capacity(2);
1154
1155 if !stale_message.is_empty() {
1156 content.push(stale_message.into());
1157 }
1158
1159 if !content.is_empty() {
1160 let context_message = LanguageModelRequestMessage {
1161 role: Role::User,
1162 content,
1163 cache: false,
1164 };
1165
1166 messages.push(context_message);
1167 }
1168 }
1169
1170 pub fn stream_completion(
1171 &mut self,
1172 request: LanguageModelRequest,
1173 model: Arc<dyn LanguageModel>,
1174 cx: &mut Context<Self>,
1175 ) {
1176 let pending_completion_id = post_inc(&mut self.completion_count);
1177 let mut request_callback_parameters = if self.request_callback.is_some() {
1178 Some((request.clone(), Vec::new()))
1179 } else {
1180 None
1181 };
1182 let prompt_id = self.last_prompt_id.clone();
1183 let tool_use_metadata = ToolUseMetadata {
1184 model: model.clone(),
1185 thread_id: self.id.clone(),
1186 prompt_id: prompt_id.clone(),
1187 };
1188
1189 let task = cx.spawn(async move |thread, cx| {
1190 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1191 let initial_token_usage =
1192 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1193 let stream_completion = async {
1194 let (mut events, usage) = stream_completion_future.await?;
1195
1196 let mut stop_reason = StopReason::EndTurn;
1197 let mut current_token_usage = TokenUsage::default();
1198
1199 if let Some(usage) = usage {
1200 thread
1201 .update(cx, |_thread, cx| {
1202 cx.emit(ThreadEvent::UsageUpdated(usage));
1203 })
1204 .ok();
1205 }
1206
1207 while let Some(event) = events.next().await {
1208 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1209 response_events
1210 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1211 }
1212
1213 let event = event?;
1214
1215 thread.update(cx, |thread, cx| {
1216 match event {
1217 LanguageModelCompletionEvent::StartMessage { .. } => {
1218 thread.insert_message(
1219 Role::Assistant,
1220 vec![MessageSegment::Text(String::new())],
1221 cx,
1222 );
1223 }
1224 LanguageModelCompletionEvent::Stop(reason) => {
1225 stop_reason = reason;
1226 }
1227 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1228 thread.update_token_usage_at_last_message(token_usage);
1229 thread.cumulative_token_usage = thread.cumulative_token_usage
1230 + token_usage
1231 - current_token_usage;
1232 current_token_usage = token_usage;
1233 }
1234 LanguageModelCompletionEvent::Text(chunk) => {
1235 if let Some(last_message) = thread.messages.last_mut() {
1236 if last_message.role == Role::Assistant {
1237 last_message.push_text(&chunk);
1238 cx.emit(ThreadEvent::StreamedAssistantText(
1239 last_message.id,
1240 chunk,
1241 ));
1242 } else {
1243 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1244 // of a new Assistant response.
1245 //
1246 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1247 // will result in duplicating the text of the chunk in the rendered Markdown.
1248 thread.insert_message(
1249 Role::Assistant,
1250 vec![MessageSegment::Text(chunk.to_string())],
1251 cx,
1252 );
1253 };
1254 }
1255 }
1256 LanguageModelCompletionEvent::Thinking {
1257 text: chunk,
1258 signature,
1259 } => {
1260 if let Some(last_message) = thread.messages.last_mut() {
1261 if last_message.role == Role::Assistant {
1262 last_message.push_thinking(&chunk, signature);
1263 cx.emit(ThreadEvent::StreamedAssistantThinking(
1264 last_message.id,
1265 chunk,
1266 ));
1267 } else {
1268 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1269 // of a new Assistant response.
1270 //
1271 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1272 // will result in duplicating the text of the chunk in the rendered Markdown.
1273 thread.insert_message(
1274 Role::Assistant,
1275 vec![MessageSegment::Thinking {
1276 text: chunk.to_string(),
1277 signature,
1278 }],
1279 cx,
1280 );
1281 };
1282 }
1283 }
1284 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1285 let last_assistant_message_id = thread
1286 .messages
1287 .iter_mut()
1288 .rfind(|message| message.role == Role::Assistant)
1289 .map(|message| message.id)
1290 .unwrap_or_else(|| {
1291 thread.insert_message(Role::Assistant, vec![], cx)
1292 });
1293
1294 thread.tool_use.request_tool_use(
1295 last_assistant_message_id,
1296 tool_use,
1297 tool_use_metadata.clone(),
1298 cx,
1299 );
1300 }
1301 }
1302
1303 thread.touch_updated_at();
1304 cx.emit(ThreadEvent::StreamedCompletion);
1305 cx.notify();
1306
1307 thread.auto_capture_telemetry(cx);
1308 })?;
1309
1310 smol::future::yield_now().await;
1311 }
1312
1313 thread.update(cx, |thread, cx| {
1314 thread
1315 .pending_completions
1316 .retain(|completion| completion.id != pending_completion_id);
1317
1318 // If there is a response without tool use, summarize the message. Otherwise,
1319 // allow two tool uses before summarizing.
1320 if thread.summary.is_none()
1321 && thread.messages.len() >= 2
1322 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1323 {
1324 thread.summarize(cx);
1325 }
1326 })?;
1327
1328 anyhow::Ok(stop_reason)
1329 };
1330
1331 let result = stream_completion.await;
1332
1333 thread
1334 .update(cx, |thread, cx| {
1335 thread.finalize_pending_checkpoint(cx);
1336 match result.as_ref() {
1337 Ok(stop_reason) => match stop_reason {
1338 StopReason::ToolUse => {
1339 let tool_uses = thread.use_pending_tools(cx);
1340 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1341 }
1342 StopReason::EndTurn => {}
1343 StopReason::MaxTokens => {}
1344 },
1345 Err(error) => {
1346 if error.is::<PaymentRequiredError>() {
1347 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1348 } else if error.is::<MaxMonthlySpendReachedError>() {
1349 cx.emit(ThreadEvent::ShowError(
1350 ThreadError::MaxMonthlySpendReached,
1351 ));
1352 } else if let Some(error) =
1353 error.downcast_ref::<ModelRequestLimitReachedError>()
1354 {
1355 cx.emit(ThreadEvent::ShowError(
1356 ThreadError::ModelRequestLimitReached { plan: error.plan },
1357 ));
1358 } else if let Some(known_error) =
1359 error.downcast_ref::<LanguageModelKnownError>()
1360 {
1361 match known_error {
1362 LanguageModelKnownError::ContextWindowLimitExceeded {
1363 tokens,
1364 } => {
1365 thread.exceeded_window_error = Some(ExceededWindowError {
1366 model_id: model.id(),
1367 token_count: *tokens,
1368 });
1369 cx.notify();
1370 }
1371 }
1372 } else {
1373 let error_message = error
1374 .chain()
1375 .map(|err| err.to_string())
1376 .collect::<Vec<_>>()
1377 .join("\n");
1378 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1379 header: "Error interacting with language model".into(),
1380 message: SharedString::from(error_message.clone()),
1381 }));
1382 }
1383
1384 thread.cancel_last_completion(cx);
1385 }
1386 }
1387 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1388
1389 if let Some((request_callback, (request, response_events))) = thread
1390 .request_callback
1391 .as_mut()
1392 .zip(request_callback_parameters.as_ref())
1393 {
1394 request_callback(request, response_events);
1395 }
1396
1397 thread.auto_capture_telemetry(cx);
1398
1399 if let Ok(initial_usage) = initial_token_usage {
1400 let usage = thread.cumulative_token_usage - initial_usage;
1401
1402 telemetry::event!(
1403 "Assistant Thread Completion",
1404 thread_id = thread.id().to_string(),
1405 prompt_id = prompt_id,
1406 model = model.telemetry_id(),
1407 model_provider = model.provider_id().to_string(),
1408 input_tokens = usage.input_tokens,
1409 output_tokens = usage.output_tokens,
1410 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1411 cache_read_input_tokens = usage.cache_read_input_tokens,
1412 );
1413 }
1414 })
1415 .ok();
1416 });
1417
1418 self.pending_completions.push(PendingCompletion {
1419 id: pending_completion_id,
1420 _task: task,
1421 });
1422 }
1423
1424 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1425 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1426 return;
1427 };
1428
1429 if !model.provider.is_authenticated(cx) {
1430 return;
1431 }
1432
1433 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1434 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1435 If the conversation is about a specific subject, include it in the title. \
1436 Be descriptive. DO NOT speak in the first person.";
1437
1438 let request = self.to_summarize_request(added_user_message.into());
1439
1440 self.pending_summary = cx.spawn(async move |this, cx| {
1441 async move {
1442 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1443 let (mut messages, usage) = stream.await?;
1444
1445 if let Some(usage) = usage {
1446 this.update(cx, |_thread, cx| {
1447 cx.emit(ThreadEvent::UsageUpdated(usage));
1448 })
1449 .ok();
1450 }
1451
1452 let mut new_summary = String::new();
1453 while let Some(message) = messages.stream.next().await {
1454 let text = message?;
1455 let mut lines = text.lines();
1456 new_summary.extend(lines.next());
1457
1458 // Stop if the LLM generated multiple lines.
1459 if lines.next().is_some() {
1460 break;
1461 }
1462 }
1463
1464 this.update(cx, |this, cx| {
1465 if !new_summary.is_empty() {
1466 this.summary = Some(new_summary.into());
1467 }
1468
1469 cx.emit(ThreadEvent::SummaryGenerated);
1470 })?;
1471
1472 anyhow::Ok(())
1473 }
1474 .log_err()
1475 .await
1476 });
1477 }
1478
1479 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1480 let last_message_id = self.messages.last().map(|message| message.id)?;
1481
1482 match &self.detailed_summary_state {
1483 DetailedSummaryState::Generating { message_id, .. }
1484 | DetailedSummaryState::Generated { message_id, .. }
1485 if *message_id == last_message_id =>
1486 {
1487 // Already up-to-date
1488 return None;
1489 }
1490 _ => {}
1491 }
1492
1493 let ConfiguredModel { model, provider } =
1494 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1495
1496 if !provider.is_authenticated(cx) {
1497 return None;
1498 }
1499
1500 let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1501 1. A brief overview of what was discussed\n\
1502 2. Key facts or information discovered\n\
1503 3. Outcomes or conclusions reached\n\
1504 4. Any action items or next steps if any\n\
1505 Format it in Markdown with headings and bullet points.";
1506
1507 let request = self.to_summarize_request(added_user_message.into());
1508
1509 let task = cx.spawn(async move |thread, cx| {
1510 let stream = model.stream_completion_text(request, &cx);
1511 let Some(mut messages) = stream.await.log_err() else {
1512 thread
1513 .update(cx, |this, _cx| {
1514 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1515 })
1516 .log_err();
1517
1518 return;
1519 };
1520
1521 let mut new_detailed_summary = String::new();
1522
1523 while let Some(chunk) = messages.stream.next().await {
1524 if let Some(chunk) = chunk.log_err() {
1525 new_detailed_summary.push_str(&chunk);
1526 }
1527 }
1528
1529 thread
1530 .update(cx, |this, _cx| {
1531 this.detailed_summary_state = DetailedSummaryState::Generated {
1532 text: new_detailed_summary.into(),
1533 message_id: last_message_id,
1534 };
1535 })
1536 .log_err();
1537 });
1538
1539 self.detailed_summary_state = DetailedSummaryState::Generating {
1540 message_id: last_message_id,
1541 };
1542
1543 Some(task)
1544 }
1545
1546 pub fn is_generating_detailed_summary(&self) -> bool {
1547 matches!(
1548 self.detailed_summary_state,
1549 DetailedSummaryState::Generating { .. }
1550 )
1551 }
1552
1553 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1554 self.auto_capture_telemetry(cx);
1555 let request = self.to_completion_request(cx);
1556 let messages = Arc::new(request.messages);
1557 let pending_tool_uses = self
1558 .tool_use
1559 .pending_tool_uses()
1560 .into_iter()
1561 .filter(|tool_use| tool_use.status.is_idle())
1562 .cloned()
1563 .collect::<Vec<_>>();
1564
1565 for tool_use in pending_tool_uses.iter() {
1566 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1567 if tool.needs_confirmation(&tool_use.input, cx)
1568 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1569 {
1570 self.tool_use.confirm_tool_use(
1571 tool_use.id.clone(),
1572 tool_use.ui_text.clone(),
1573 tool_use.input.clone(),
1574 messages.clone(),
1575 tool,
1576 );
1577 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1578 } else {
1579 self.run_tool(
1580 tool_use.id.clone(),
1581 tool_use.ui_text.clone(),
1582 tool_use.input.clone(),
1583 &messages,
1584 tool,
1585 cx,
1586 );
1587 }
1588 }
1589 }
1590
1591 pending_tool_uses
1592 }
1593
1594 pub fn run_tool(
1595 &mut self,
1596 tool_use_id: LanguageModelToolUseId,
1597 ui_text: impl Into<SharedString>,
1598 input: serde_json::Value,
1599 messages: &[LanguageModelRequestMessage],
1600 tool: Arc<dyn Tool>,
1601 cx: &mut Context<Thread>,
1602 ) {
1603 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1604 self.tool_use
1605 .run_pending_tool(tool_use_id, ui_text.into(), task);
1606 }
1607
1608 fn spawn_tool_use(
1609 &mut self,
1610 tool_use_id: LanguageModelToolUseId,
1611 messages: &[LanguageModelRequestMessage],
1612 input: serde_json::Value,
1613 tool: Arc<dyn Tool>,
1614 cx: &mut Context<Thread>,
1615 ) -> Task<()> {
1616 let tool_name: Arc<str> = tool.name().into();
1617
1618 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1619 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1620 } else {
1621 tool.run(
1622 input,
1623 messages,
1624 self.project.clone(),
1625 self.action_log.clone(),
1626 cx,
1627 )
1628 };
1629
1630 // Store the card separately if it exists
1631 if let Some(card) = tool_result.card.clone() {
1632 self.tool_use
1633 .insert_tool_result_card(tool_use_id.clone(), card);
1634 }
1635
1636 cx.spawn({
1637 async move |thread: WeakEntity<Thread>, cx| {
1638 let output = tool_result.output.await;
1639
1640 thread
1641 .update(cx, |thread, cx| {
1642 let pending_tool_use = thread.tool_use.insert_tool_output(
1643 tool_use_id.clone(),
1644 tool_name,
1645 output,
1646 cx,
1647 );
1648 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1649 })
1650 .ok();
1651 }
1652 })
1653 }
1654
1655 fn tool_finished(
1656 &mut self,
1657 tool_use_id: LanguageModelToolUseId,
1658 pending_tool_use: Option<PendingToolUse>,
1659 canceled: bool,
1660 cx: &mut Context<Self>,
1661 ) {
1662 if self.all_tools_finished() {
1663 let model_registry = LanguageModelRegistry::read_global(cx);
1664 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1665 self.attach_tool_results(cx);
1666 if !canceled {
1667 self.send_to_model(model, cx);
1668 }
1669 }
1670 }
1671
1672 cx.emit(ThreadEvent::ToolFinished {
1673 tool_use_id,
1674 pending_tool_use,
1675 });
1676 }
1677
1678 /// Insert an empty message to be populated with tool results upon send.
1679 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1680 // Tool results are assumed to be waiting on the next message id, so they will populate
1681 // this empty message before sending to model. Would prefer this to be more straightforward.
1682 self.insert_message(Role::User, vec![], cx);
1683 self.auto_capture_telemetry(cx);
1684 }
1685
1686 /// Cancels the last pending completion, if there are any pending.
1687 ///
1688 /// Returns whether a completion was canceled.
1689 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1690 let canceled = if self.pending_completions.pop().is_some() {
1691 true
1692 } else {
1693 let mut canceled = false;
1694 for pending_tool_use in self.tool_use.cancel_pending() {
1695 canceled = true;
1696 self.tool_finished(
1697 pending_tool_use.id.clone(),
1698 Some(pending_tool_use),
1699 true,
1700 cx,
1701 );
1702 }
1703 canceled
1704 };
1705 self.finalize_pending_checkpoint(cx);
1706 canceled
1707 }
1708
1709 pub fn feedback(&self) -> Option<ThreadFeedback> {
1710 self.feedback
1711 }
1712
1713 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1714 self.message_feedback.get(&message_id).copied()
1715 }
1716
1717 pub fn report_message_feedback(
1718 &mut self,
1719 message_id: MessageId,
1720 feedback: ThreadFeedback,
1721 cx: &mut Context<Self>,
1722 ) -> Task<Result<()>> {
1723 if self.message_feedback.get(&message_id) == Some(&feedback) {
1724 return Task::ready(Ok(()));
1725 }
1726
1727 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1728 let serialized_thread = self.serialize(cx);
1729 let thread_id = self.id().clone();
1730 let client = self.project.read(cx).client();
1731
1732 let enabled_tool_names: Vec<String> = self
1733 .tools()
1734 .read(cx)
1735 .enabled_tools(cx)
1736 .iter()
1737 .map(|tool| tool.name().to_string())
1738 .collect();
1739
1740 self.message_feedback.insert(message_id, feedback);
1741
1742 cx.notify();
1743
1744 let message_content = self
1745 .message(message_id)
1746 .map(|msg| msg.to_string())
1747 .unwrap_or_default();
1748
1749 cx.background_spawn(async move {
1750 let final_project_snapshot = final_project_snapshot.await;
1751 let serialized_thread = serialized_thread.await?;
1752 let thread_data =
1753 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1754
1755 let rating = match feedback {
1756 ThreadFeedback::Positive => "positive",
1757 ThreadFeedback::Negative => "negative",
1758 };
1759 telemetry::event!(
1760 "Assistant Thread Rated",
1761 rating,
1762 thread_id,
1763 enabled_tool_names,
1764 message_id = message_id.0,
1765 message_content,
1766 thread_data,
1767 final_project_snapshot
1768 );
1769 client.telemetry().flush_events();
1770
1771 Ok(())
1772 })
1773 }
1774
1775 pub fn report_feedback(
1776 &mut self,
1777 feedback: ThreadFeedback,
1778 cx: &mut Context<Self>,
1779 ) -> Task<Result<()>> {
1780 let last_assistant_message_id = self
1781 .messages
1782 .iter()
1783 .rev()
1784 .find(|msg| msg.role == Role::Assistant)
1785 .map(|msg| msg.id);
1786
1787 if let Some(message_id) = last_assistant_message_id {
1788 self.report_message_feedback(message_id, feedback, cx)
1789 } else {
1790 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1791 let serialized_thread = self.serialize(cx);
1792 let thread_id = self.id().clone();
1793 let client = self.project.read(cx).client();
1794 self.feedback = Some(feedback);
1795 cx.notify();
1796
1797 cx.background_spawn(async move {
1798 let final_project_snapshot = final_project_snapshot.await;
1799 let serialized_thread = serialized_thread.await?;
1800 let thread_data = serde_json::to_value(serialized_thread)
1801 .unwrap_or_else(|_| serde_json::Value::Null);
1802
1803 let rating = match feedback {
1804 ThreadFeedback::Positive => "positive",
1805 ThreadFeedback::Negative => "negative",
1806 };
1807 telemetry::event!(
1808 "Assistant Thread Rated",
1809 rating,
1810 thread_id,
1811 thread_data,
1812 final_project_snapshot
1813 );
1814 client.telemetry().flush_events();
1815
1816 Ok(())
1817 })
1818 }
1819 }
1820
1821 /// Create a snapshot of the current project state including git information and unsaved buffers.
1822 fn project_snapshot(
1823 project: Entity<Project>,
1824 cx: &mut Context<Self>,
1825 ) -> Task<Arc<ProjectSnapshot>> {
1826 let git_store = project.read(cx).git_store().clone();
1827 let worktree_snapshots: Vec<_> = project
1828 .read(cx)
1829 .visible_worktrees(cx)
1830 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1831 .collect();
1832
1833 cx.spawn(async move |_, cx| {
1834 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1835
1836 let mut unsaved_buffers = Vec::new();
1837 cx.update(|app_cx| {
1838 let buffer_store = project.read(app_cx).buffer_store();
1839 for buffer_handle in buffer_store.read(app_cx).buffers() {
1840 let buffer = buffer_handle.read(app_cx);
1841 if buffer.is_dirty() {
1842 if let Some(file) = buffer.file() {
1843 let path = file.path().to_string_lossy().to_string();
1844 unsaved_buffers.push(path);
1845 }
1846 }
1847 }
1848 })
1849 .ok();
1850
1851 Arc::new(ProjectSnapshot {
1852 worktree_snapshots,
1853 unsaved_buffer_paths: unsaved_buffers,
1854 timestamp: Utc::now(),
1855 })
1856 })
1857 }
1858
1859 fn worktree_snapshot(
1860 worktree: Entity<project::Worktree>,
1861 git_store: Entity<GitStore>,
1862 cx: &App,
1863 ) -> Task<WorktreeSnapshot> {
1864 cx.spawn(async move |cx| {
1865 // Get worktree path and snapshot
1866 let worktree_info = cx.update(|app_cx| {
1867 let worktree = worktree.read(app_cx);
1868 let path = worktree.abs_path().to_string_lossy().to_string();
1869 let snapshot = worktree.snapshot();
1870 (path, snapshot)
1871 });
1872
1873 let Ok((worktree_path, _snapshot)) = worktree_info else {
1874 return WorktreeSnapshot {
1875 worktree_path: String::new(),
1876 git_state: None,
1877 };
1878 };
1879
1880 let git_state = git_store
1881 .update(cx, |git_store, cx| {
1882 git_store
1883 .repositories()
1884 .values()
1885 .find(|repo| {
1886 repo.read(cx)
1887 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1888 .is_some()
1889 })
1890 .cloned()
1891 })
1892 .ok()
1893 .flatten()
1894 .map(|repo| {
1895 repo.update(cx, |repo, _| {
1896 let current_branch =
1897 repo.branch.as_ref().map(|branch| branch.name.to_string());
1898 repo.send_job(None, |state, _| async move {
1899 let RepositoryState::Local { backend, .. } = state else {
1900 return GitState {
1901 remote_url: None,
1902 head_sha: None,
1903 current_branch,
1904 diff: None,
1905 };
1906 };
1907
1908 let remote_url = backend.remote_url("origin");
1909 let head_sha = backend.head_sha();
1910 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1911
1912 GitState {
1913 remote_url,
1914 head_sha,
1915 current_branch,
1916 diff,
1917 }
1918 })
1919 })
1920 });
1921
1922 let git_state = match git_state {
1923 Some(git_state) => match git_state.ok() {
1924 Some(git_state) => git_state.await.ok(),
1925 None => None,
1926 },
1927 None => None,
1928 };
1929
1930 WorktreeSnapshot {
1931 worktree_path,
1932 git_state,
1933 }
1934 })
1935 }
1936
1937 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1938 let mut markdown = Vec::new();
1939
1940 if let Some(summary) = self.summary() {
1941 writeln!(markdown, "# {summary}\n")?;
1942 };
1943
1944 for message in self.messages() {
1945 writeln!(
1946 markdown,
1947 "## {role}\n",
1948 role = match message.role {
1949 Role::User => "User",
1950 Role::Assistant => "Assistant",
1951 Role::System => "System",
1952 }
1953 )?;
1954
1955 if !message.context.is_empty() {
1956 writeln!(markdown, "{}", message.context)?;
1957 }
1958
1959 for segment in &message.segments {
1960 match segment {
1961 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1962 MessageSegment::Thinking { text, .. } => {
1963 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
1964 }
1965 MessageSegment::RedactedThinking(_) => {}
1966 }
1967 }
1968
1969 for tool_use in self.tool_uses_for_message(message.id, cx) {
1970 writeln!(
1971 markdown,
1972 "**Use Tool: {} ({})**",
1973 tool_use.name, tool_use.id
1974 )?;
1975 writeln!(markdown, "```json")?;
1976 writeln!(
1977 markdown,
1978 "{}",
1979 serde_json::to_string_pretty(&tool_use.input)?
1980 )?;
1981 writeln!(markdown, "```")?;
1982 }
1983
1984 for tool_result in self.tool_results_for_message(message.id) {
1985 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1986 if tool_result.is_error {
1987 write!(markdown, " (Error)")?;
1988 }
1989
1990 writeln!(markdown, "**\n")?;
1991 writeln!(markdown, "{}", tool_result.content)?;
1992 }
1993 }
1994
1995 Ok(String::from_utf8_lossy(&markdown).to_string())
1996 }
1997
1998 pub fn keep_edits_in_range(
1999 &mut self,
2000 buffer: Entity<language::Buffer>,
2001 buffer_range: Range<language::Anchor>,
2002 cx: &mut Context<Self>,
2003 ) {
2004 self.action_log.update(cx, |action_log, cx| {
2005 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2006 });
2007 }
2008
2009 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2010 self.action_log
2011 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2012 }
2013
2014 pub fn reject_edits_in_ranges(
2015 &mut self,
2016 buffer: Entity<language::Buffer>,
2017 buffer_ranges: Vec<Range<language::Anchor>>,
2018 cx: &mut Context<Self>,
2019 ) -> Task<Result<()>> {
2020 self.action_log.update(cx, |action_log, cx| {
2021 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2022 })
2023 }
2024
2025 pub fn action_log(&self) -> &Entity<ActionLog> {
2026 &self.action_log
2027 }
2028
2029 pub fn project(&self) -> &Entity<Project> {
2030 &self.project
2031 }
2032
2033 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2034 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2035 return;
2036 }
2037
2038 let now = Instant::now();
2039 if let Some(last) = self.last_auto_capture_at {
2040 if now.duration_since(last).as_secs() < 10 {
2041 return;
2042 }
2043 }
2044
2045 self.last_auto_capture_at = Some(now);
2046
2047 let thread_id = self.id().clone();
2048 let github_login = self
2049 .project
2050 .read(cx)
2051 .user_store()
2052 .read(cx)
2053 .current_user()
2054 .map(|user| user.github_login.clone());
2055 let client = self.project.read(cx).client().clone();
2056 let serialize_task = self.serialize(cx);
2057
2058 cx.background_executor()
2059 .spawn(async move {
2060 if let Ok(serialized_thread) = serialize_task.await {
2061 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2062 telemetry::event!(
2063 "Agent Thread Auto-Captured",
2064 thread_id = thread_id.to_string(),
2065 thread_data = thread_data,
2066 auto_capture_reason = "tracked_user",
2067 github_login = github_login
2068 );
2069
2070 client.telemetry().flush_events();
2071 }
2072 }
2073 })
2074 .detach();
2075 }
2076
2077 pub fn cumulative_token_usage(&self) -> TokenUsage {
2078 self.cumulative_token_usage
2079 }
2080
2081 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2082 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2083 return TotalTokenUsage::default();
2084 };
2085
2086 let max = model.model.max_token_count();
2087
2088 let index = self
2089 .messages
2090 .iter()
2091 .position(|msg| msg.id == message_id)
2092 .unwrap_or(0);
2093
2094 if index == 0 {
2095 return TotalTokenUsage { total: 0, max };
2096 }
2097
2098 let token_usage = &self
2099 .request_token_usage
2100 .get(index - 1)
2101 .cloned()
2102 .unwrap_or_default();
2103
2104 TotalTokenUsage {
2105 total: token_usage.total_tokens() as usize,
2106 max,
2107 }
2108 }
2109
2110 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2111 let model_registry = LanguageModelRegistry::read_global(cx);
2112 let Some(model) = model_registry.default_model() else {
2113 return TotalTokenUsage::default();
2114 };
2115
2116 let max = model.model.max_token_count();
2117
2118 if let Some(exceeded_error) = &self.exceeded_window_error {
2119 if model.model.id() == exceeded_error.model_id {
2120 return TotalTokenUsage {
2121 total: exceeded_error.token_count,
2122 max,
2123 };
2124 }
2125 }
2126
2127 let total = self
2128 .token_usage_at_last_message()
2129 .unwrap_or_default()
2130 .total_tokens() as usize;
2131
2132 TotalTokenUsage { total, max }
2133 }
2134
2135 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2136 self.request_token_usage
2137 .get(self.messages.len().saturating_sub(1))
2138 .or_else(|| self.request_token_usage.last())
2139 .cloned()
2140 }
2141
2142 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2143 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2144 self.request_token_usage
2145 .resize(self.messages.len(), placeholder);
2146
2147 if let Some(last) = self.request_token_usage.last_mut() {
2148 *last = token_usage;
2149 }
2150 }
2151
2152 pub fn deny_tool_use(
2153 &mut self,
2154 tool_use_id: LanguageModelToolUseId,
2155 tool_name: Arc<str>,
2156 cx: &mut Context<Self>,
2157 ) {
2158 let err = Err(anyhow::anyhow!(
2159 "Permission to run tool action denied by user"
2160 ));
2161
2162 self.tool_use
2163 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2164 self.tool_finished(tool_use_id.clone(), None, true, cx);
2165 }
2166}
2167
2168#[derive(Debug, Clone, Error)]
2169pub enum ThreadError {
2170 #[error("Payment required")]
2171 PaymentRequired,
2172 #[error("Max monthly spend reached")]
2173 MaxMonthlySpendReached,
2174 #[error("Model request limit reached")]
2175 ModelRequestLimitReached { plan: Plan },
2176 #[error("Message {header}: {message}")]
2177 Message {
2178 header: SharedString,
2179 message: SharedString,
2180 },
2181}
2182
2183#[derive(Debug, Clone)]
2184pub enum ThreadEvent {
2185 ShowError(ThreadError),
2186 UsageUpdated(RequestUsage),
2187 StreamedCompletion,
2188 StreamedAssistantText(MessageId, String),
2189 StreamedAssistantThinking(MessageId, String),
2190 Stopped(Result<StopReason, Arc<anyhow::Error>>),
2191 MessageAdded(MessageId),
2192 MessageEdited(MessageId),
2193 MessageDeleted(MessageId),
2194 SummaryGenerated,
2195 SummaryChanged,
2196 UsePendingTools {
2197 tool_uses: Vec<PendingToolUse>,
2198 },
2199 ToolFinished {
2200 #[allow(unused)]
2201 tool_use_id: LanguageModelToolUseId,
2202 /// The pending tool use that corresponds to this tool.
2203 pending_tool_use: Option<PendingToolUse>,
2204 },
2205 CheckpointChanged,
2206 ToolConfirmationNeeded,
2207}
2208
2209impl EventEmitter<ThreadEvent> for Thread {}
2210
2211struct PendingCompletion {
2212 id: usize,
2213 _task: Task<()>,
2214}
2215
2216#[cfg(test)]
2217mod tests {
2218 use super::*;
2219 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2220 use assistant_settings::AssistantSettings;
2221 use context_server::ContextServerSettings;
2222 use editor::EditorSettings;
2223 use gpui::TestAppContext;
2224 use project::{FakeFs, Project};
2225 use prompt_store::PromptBuilder;
2226 use serde_json::json;
2227 use settings::{Settings, SettingsStore};
2228 use std::sync::Arc;
2229 use theme::ThemeSettings;
2230 use util::path;
2231 use workspace::Workspace;
2232
2233 #[gpui::test]
2234 async fn test_message_with_context(cx: &mut TestAppContext) {
2235 init_test_settings(cx);
2236
2237 let project = create_test_project(
2238 cx,
2239 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2240 )
2241 .await;
2242
2243 let (_workspace, _thread_store, thread, context_store) =
2244 setup_test_environment(cx, project.clone()).await;
2245
2246 add_file_to_context(&project, &context_store, "test/code.rs", cx)
2247 .await
2248 .unwrap();
2249
2250 let context =
2251 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2252
2253 // Insert user message with context
2254 let message_id = thread.update(cx, |thread, cx| {
2255 thread.insert_user_message("Please explain this code", vec![context], None, cx)
2256 });
2257
2258 // Check content and context in message object
2259 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2260
2261 // Use different path format strings based on platform for the test
2262 #[cfg(windows)]
2263 let path_part = r"test\code.rs";
2264 #[cfg(not(windows))]
2265 let path_part = "test/code.rs";
2266
2267 let expected_context = format!(
2268 r#"
2269<context>
2270The following items were attached by the user. You don't need to use other tools to read them.
2271
2272<files>
2273```rs {path_part}
2274fn main() {{
2275 println!("Hello, world!");
2276}}
2277```
2278</files>
2279</context>
2280"#
2281 );
2282
2283 assert_eq!(message.role, Role::User);
2284 assert_eq!(message.segments.len(), 1);
2285 assert_eq!(
2286 message.segments[0],
2287 MessageSegment::Text("Please explain this code".to_string())
2288 );
2289 assert_eq!(message.context, expected_context);
2290
2291 // Check message in request
2292 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2293
2294 assert_eq!(request.messages.len(), 2);
2295 let expected_full_message = format!("{}Please explain this code", expected_context);
2296 assert_eq!(request.messages[1].string_contents(), expected_full_message);
2297 }
2298
2299 #[gpui::test]
2300 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2301 init_test_settings(cx);
2302
2303 let project = create_test_project(
2304 cx,
2305 json!({
2306 "file1.rs": "fn function1() {}\n",
2307 "file2.rs": "fn function2() {}\n",
2308 "file3.rs": "fn function3() {}\n",
2309 }),
2310 )
2311 .await;
2312
2313 let (_, _thread_store, thread, context_store) =
2314 setup_test_environment(cx, project.clone()).await;
2315
2316 // Open files individually
2317 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2318 .await
2319 .unwrap();
2320 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2321 .await
2322 .unwrap();
2323 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2324 .await
2325 .unwrap();
2326
2327 // Get the context objects
2328 let contexts = context_store.update(cx, |store, _| store.context().clone());
2329 assert_eq!(contexts.len(), 3);
2330
2331 // First message with context 1
2332 let message1_id = thread.update(cx, |thread, cx| {
2333 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2334 });
2335
2336 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2337 let message2_id = thread.update(cx, |thread, cx| {
2338 thread.insert_user_message(
2339 "Message 2",
2340 vec![contexts[0].clone(), contexts[1].clone()],
2341 None,
2342 cx,
2343 )
2344 });
2345
2346 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2347 let message3_id = thread.update(cx, |thread, cx| {
2348 thread.insert_user_message(
2349 "Message 3",
2350 vec![
2351 contexts[0].clone(),
2352 contexts[1].clone(),
2353 contexts[2].clone(),
2354 ],
2355 None,
2356 cx,
2357 )
2358 });
2359
2360 // Check what contexts are included in each message
2361 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2362 (
2363 thread.message(message1_id).unwrap().clone(),
2364 thread.message(message2_id).unwrap().clone(),
2365 thread.message(message3_id).unwrap().clone(),
2366 )
2367 });
2368
2369 // First message should include context 1
2370 assert!(message1.context.contains("file1.rs"));
2371
2372 // Second message should include only context 2 (not 1)
2373 assert!(!message2.context.contains("file1.rs"));
2374 assert!(message2.context.contains("file2.rs"));
2375
2376 // Third message should include only context 3 (not 1 or 2)
2377 assert!(!message3.context.contains("file1.rs"));
2378 assert!(!message3.context.contains("file2.rs"));
2379 assert!(message3.context.contains("file3.rs"));
2380
2381 // Check entire request to make sure all contexts are properly included
2382 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2383
2384 // The request should contain all 3 messages
2385 assert_eq!(request.messages.len(), 4);
2386
2387 // Check that the contexts are properly formatted in each message
2388 assert!(request.messages[1].string_contents().contains("file1.rs"));
2389 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2390 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2391
2392 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2393 assert!(request.messages[2].string_contents().contains("file2.rs"));
2394 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2395
2396 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2397 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2398 assert!(request.messages[3].string_contents().contains("file3.rs"));
2399 }
2400
2401 #[gpui::test]
2402 async fn test_message_without_files(cx: &mut TestAppContext) {
2403 init_test_settings(cx);
2404
2405 let project = create_test_project(
2406 cx,
2407 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2408 )
2409 .await;
2410
2411 let (_, _thread_store, thread, _context_store) =
2412 setup_test_environment(cx, project.clone()).await;
2413
2414 // Insert user message without any context (empty context vector)
2415 let message_id = thread.update(cx, |thread, cx| {
2416 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2417 });
2418
2419 // Check content and context in message object
2420 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2421
2422 // Context should be empty when no files are included
2423 assert_eq!(message.role, Role::User);
2424 assert_eq!(message.segments.len(), 1);
2425 assert_eq!(
2426 message.segments[0],
2427 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2428 );
2429 assert_eq!(message.context, "");
2430
2431 // Check message in request
2432 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2433
2434 assert_eq!(request.messages.len(), 2);
2435 assert_eq!(
2436 request.messages[1].string_contents(),
2437 "What is the best way to learn Rust?"
2438 );
2439
2440 // Add second message, also without context
2441 let message2_id = thread.update(cx, |thread, cx| {
2442 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2443 });
2444
2445 let message2 =
2446 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2447 assert_eq!(message2.context, "");
2448
2449 // Check that both messages appear in the request
2450 let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2451
2452 assert_eq!(request.messages.len(), 3);
2453 assert_eq!(
2454 request.messages[1].string_contents(),
2455 "What is the best way to learn Rust?"
2456 );
2457 assert_eq!(
2458 request.messages[2].string_contents(),
2459 "Are there any good books?"
2460 );
2461 }
2462
2463 #[gpui::test]
2464 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2465 init_test_settings(cx);
2466
2467 let project = create_test_project(
2468 cx,
2469 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2470 )
2471 .await;
2472
2473 let (_workspace, _thread_store, thread, context_store) =
2474 setup_test_environment(cx, project.clone()).await;
2475
2476 // Open buffer and add it to context
2477 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2478 .await
2479 .unwrap();
2480
2481 let context =
2482 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2483
2484 // Insert user message with the buffer as context
2485 thread.update(cx, |thread, cx| {
2486 thread.insert_user_message("Explain this code", vec![context], None, cx)
2487 });
2488
2489 // Create a request and check that it doesn't have a stale buffer warning yet
2490 let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2491
2492 // Make sure we don't have a stale file warning yet
2493 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2494 msg.string_contents()
2495 .contains("These files changed since last read:")
2496 });
2497 assert!(
2498 !has_stale_warning,
2499 "Should not have stale buffer warning before buffer is modified"
2500 );
2501
2502 // Modify the buffer
2503 buffer.update(cx, |buffer, cx| {
2504 // Find a position at the end of line 1
2505 buffer.edit(
2506 [(1..1, "\n println!(\"Added a new line\");\n")],
2507 None,
2508 cx,
2509 );
2510 });
2511
2512 // Insert another user message without context
2513 thread.update(cx, |thread, cx| {
2514 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2515 });
2516
2517 // Create a new request and check for the stale buffer warning
2518 let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
2519
2520 // We should have a stale file warning as the last message
2521 let last_message = new_request
2522 .messages
2523 .last()
2524 .expect("Request should have messages");
2525
2526 // The last message should be the stale buffer notification
2527 assert_eq!(last_message.role, Role::User);
2528
2529 // Check the exact content of the message
2530 let expected_content = "These files changed since last read:\n- code.rs\n";
2531 assert_eq!(
2532 last_message.string_contents(),
2533 expected_content,
2534 "Last message should be exactly the stale buffer notification"
2535 );
2536 }
2537
2538 fn init_test_settings(cx: &mut TestAppContext) {
2539 cx.update(|cx| {
2540 let settings_store = SettingsStore::test(cx);
2541 cx.set_global(settings_store);
2542 language::init(cx);
2543 Project::init_settings(cx);
2544 AssistantSettings::register(cx);
2545 prompt_store::init(cx);
2546 thread_store::init(cx);
2547 workspace::init_settings(cx);
2548 ThemeSettings::register(cx);
2549 ContextServerSettings::register(cx);
2550 EditorSettings::register(cx);
2551 });
2552 }
2553
2554 // Helper to create a test project with test files
2555 async fn create_test_project(
2556 cx: &mut TestAppContext,
2557 files: serde_json::Value,
2558 ) -> Entity<Project> {
2559 let fs = FakeFs::new(cx.executor());
2560 fs.insert_tree(path!("/test"), files).await;
2561 Project::test(fs, [path!("/test").as_ref()], cx).await
2562 }
2563
2564 async fn setup_test_environment(
2565 cx: &mut TestAppContext,
2566 project: Entity<Project>,
2567 ) -> (
2568 Entity<Workspace>,
2569 Entity<ThreadStore>,
2570 Entity<Thread>,
2571 Entity<ContextStore>,
2572 ) {
2573 let (workspace, cx) =
2574 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2575
2576 let thread_store = cx
2577 .update(|_, cx| {
2578 ThreadStore::load(
2579 project.clone(),
2580 cx.new(|_| ToolWorkingSet::default()),
2581 Arc::new(PromptBuilder::new(None).unwrap()),
2582 cx,
2583 )
2584 })
2585 .await
2586 .unwrap();
2587
2588 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2589 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2590
2591 (workspace, thread_store, thread, context_store)
2592 }
2593
2594 async fn add_file_to_context(
2595 project: &Entity<Project>,
2596 context_store: &Entity<ContextStore>,
2597 path: &str,
2598 cx: &mut TestAppContext,
2599 ) -> Result<Entity<language::Buffer>> {
2600 let buffer_path = project
2601 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2602 .unwrap();
2603
2604 let buffer = project
2605 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2606 .await
2607 .unwrap();
2608
2609 context_store
2610 .update(cx, |store, cx| {
2611 store.add_file_from_buffer(buffer.clone(), cx)
2612 })
2613 .await?;
2614
2615 Ok(buffer)
2616 }
2617}