1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
23};
24use project::Project;
25use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
26use prompt_store::PromptBuilder;
27use proto::Plan;
28use schemars::JsonSchema;
29use serde::{Deserialize, Serialize};
30use settings::Settings;
31use thiserror::Error;
32use util::{ResultExt as _, TryFutureExt as _, post_inc};
33use uuid::Uuid;
34use zed_llm_client::UsageLimit;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
42
43#[derive(Debug, Clone, Copy)]
44pub enum RequestKind {
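    /// Used for the normal chat flow, including tool use.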
45 Chat,
46 /// Used when summarizing a thread.
47 Summarize,
48}
49
50#[derive(
51 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
52)]
53pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
73#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
74pub struct MessageId(pub(crate) usize);
75
76impl MessageId {
77 fn post_inc(&mut self) -> Self {
78 Self(post_inc(&mut self.0))
79 }
80}
81
82/// A message in a [`Thread`].
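///
/// A minimal sketch of how segments accumulate and render (illustrative only, not
/// compiled as a doctest; the literal values are made up):
///
/// ```ignore
/// let mut message = Message {
///     id: MessageId(0),
///     role: Role::Assistant,
///     segments: Vec::new(),
///     context: String::new(),
/// };
/// message.push_thinking("planning");
/// message.push_text("Hello");
/// assert_eq!(message.to_string(), "<think>planning</think>Hello");
/// ```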
83#[derive(Debug, Clone)]
84pub struct Message {
85 pub id: MessageId,
86 pub role: Role,
87 pub segments: Vec<MessageSegment>,
88 pub context: String,
89}
90
91impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    ///
    /// The model sometimes runs tools without producing any text, or produces only a
    /// marker ([`USING_TOOL_MARKER`]); such segments are not worth displaying.
94 pub fn should_display_content(&self) -> bool {
95 self.segments.iter().all(|segment| segment.should_display())
96 }
97
98 pub fn push_thinking(&mut self, text: &str) {
99 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
100 segment.push_str(text);
101 } else {
102 self.segments
103 .push(MessageSegment::Thinking(text.to_string()));
104 }
105 }
106
107 pub fn push_text(&mut self, text: &str) {
108 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
109 segment.push_str(text);
110 } else {
111 self.segments.push(MessageSegment::Text(text.to_string()));
112 }
113 }
114
115 pub fn to_string(&self) -> String {
116 let mut result = String::new();
117
118 if !self.context.is_empty() {
119 result.push_str(&self.context);
120 }
121
122 for segment in &self.segments {
123 match segment {
124 MessageSegment::Text(text) => result.push_str(text),
125 MessageSegment::Thinking(text) => {
126 result.push_str("<think>");
127 result.push_str(text);
128 result.push_str("</think>");
129 }
130 }
131 }
132
133 result
134 }
135}
136
137#[derive(Debug, Clone, PartialEq, Eq)]
138pub enum MessageSegment {
139 Text(String),
140 Thinking(String),
141}
142
143impl MessageSegment {
144 pub fn text_mut(&mut self) -> &mut String {
145 match self {
146 Self::Text(text) => text,
147 Self::Thinking(text) => text,
148 }
149 }
150
    pub fn should_display(&self) -> bool {
        // We add USING_TOOL_MARKER when making a request that includes tool uses
        // without non-whitespace text around them, and this can cause the model
        // to mimic the pattern, so we consider those segments not displayable.
        match self {
            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::Thinking(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
        }
159 }
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct ProjectSnapshot {
164 pub worktree_snapshots: Vec<WorktreeSnapshot>,
165 pub unsaved_buffer_paths: Vec<String>,
166 pub timestamp: DateTime<Utc>,
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct WorktreeSnapshot {
171 pub worktree_path: String,
172 pub git_state: Option<GitState>,
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct GitState {
177 pub remote_url: Option<String>,
178 pub head_sha: Option<String>,
179 pub current_branch: Option<String>,
180 pub diff: Option<String>,
181}
182
183#[derive(Clone)]
184pub struct ThreadCheckpoint {
185 message_id: MessageId,
186 git_checkpoint: GitStoreCheckpoint,
187}
188
189#[derive(Copy, Clone, Debug, PartialEq, Eq)]
190pub enum ThreadFeedback {
191 Positive,
192 Negative,
193}
194
195pub enum LastRestoreCheckpoint {
196 Pending {
197 message_id: MessageId,
198 },
199 Error {
200 message_id: MessageId,
201 error: String,
202 },
203}
204
205impl LastRestoreCheckpoint {
206 pub fn message_id(&self) -> MessageId {
207 match self {
208 LastRestoreCheckpoint::Pending { message_id } => *message_id,
209 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
210 }
211 }
212}
213
214#[derive(Clone, Debug, Default, Serialize, Deserialize)]
215pub enum DetailedSummaryState {
216 #[default]
217 NotGenerated,
218 Generating {
219 message_id: MessageId,
220 },
221 Generated {
222 text: SharedString,
223 message_id: MessageId,
224 },
225}
226
227#[derive(Default)]
228pub struct TotalTokenUsage {
229 pub total: usize,
230 pub max: usize,
231}
232
233impl TotalTokenUsage {
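    /// Classifies usage relative to the context window: `Exceeded` at or above `max`,
    /// `Warning` at or above 80% of `max` (overridable in debug builds via the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable), otherwise `Normal`.
    ///
    /// A rough sketch of the thresholds (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// let usage = TotalTokenUsage { total: 900, max: 1000 };
    /// assert_eq!(usage.ratio(), TokenUsageRatio::Warning);
    /// assert_eq!(usage.add(100).ratio(), TokenUsageRatio::Exceeded);
    /// ```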
234 pub fn ratio(&self) -> TokenUsageRatio {
235 #[cfg(debug_assertions)]
236 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
237 .unwrap_or("0.8".to_string())
238 .parse()
239 .unwrap();
240 #[cfg(not(debug_assertions))]
241 let warning_threshold: f32 = 0.8;
242
243 if self.total >= self.max {
244 TokenUsageRatio::Exceeded
245 } else if self.total as f32 / self.max as f32 >= warning_threshold {
246 TokenUsageRatio::Warning
247 } else {
248 TokenUsageRatio::Normal
249 }
250 }
251
252 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
253 TotalTokenUsage {
254 total: self.total + tokens,
255 max: self.max,
256 }
257 }
258}
259
260#[derive(Debug, Default, PartialEq, Eq)]
261pub enum TokenUsageRatio {
262 #[default]
263 Normal,
264 Warning,
265 Exceeded,
266}
267
268/// A thread of conversation with the LLM.
269pub struct Thread {
270 id: ThreadId,
271 updated_at: DateTime<Utc>,
272 summary: Option<SharedString>,
273 pending_summary: Task<Option<()>>,
274 detailed_summary_state: DetailedSummaryState,
275 messages: Vec<Message>,
276 next_message_id: MessageId,
277 context: BTreeMap<ContextId, AssistantContext>,
278 context_by_message: HashMap<MessageId, Vec<ContextId>>,
279 project_context: SharedProjectContext,
280 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
281 completion_count: usize,
282 pending_completions: Vec<PendingCompletion>,
283 project: Entity<Project>,
284 prompt_builder: Arc<PromptBuilder>,
285 tools: Entity<ToolWorkingSet>,
286 tool_use: ToolUseState,
287 action_log: Entity<ActionLog>,
288 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
289 pending_checkpoint: Option<ThreadCheckpoint>,
290 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
291 request_token_usage: Vec<TokenUsage>,
292 cumulative_token_usage: TokenUsage,
293 exceeded_window_error: Option<ExceededWindowError>,
294 feedback: Option<ThreadFeedback>,
295 message_feedback: HashMap<MessageId, ThreadFeedback>,
296 last_auto_capture_at: Option<Instant>,
297}
298
299#[derive(Debug, Clone, Serialize, Deserialize)]
300pub struct ExceededWindowError {
301 /// Model used when last message exceeded context window
302 model_id: LanguageModelId,
303 /// Token count including last message
304 token_count: usize,
305}
306
307impl Thread {
308 pub fn new(
309 project: Entity<Project>,
310 tools: Entity<ToolWorkingSet>,
311 prompt_builder: Arc<PromptBuilder>,
312 system_prompt: SharedProjectContext,
313 cx: &mut Context<Self>,
314 ) -> Self {
315 Self {
316 id: ThreadId::new(),
317 updated_at: Utc::now(),
318 summary: None,
319 pending_summary: Task::ready(None),
320 detailed_summary_state: DetailedSummaryState::NotGenerated,
321 messages: Vec::new(),
322 next_message_id: MessageId(0),
323 context: BTreeMap::default(),
324 context_by_message: HashMap::default(),
325 project_context: system_prompt,
326 checkpoints_by_message: HashMap::default(),
327 completion_count: 0,
328 pending_completions: Vec::new(),
329 project: project.clone(),
330 prompt_builder,
331 tools: tools.clone(),
332 last_restore_checkpoint: None,
333 pending_checkpoint: None,
334 tool_use: ToolUseState::new(tools.clone()),
335 action_log: cx.new(|_| ActionLog::new(project.clone())),
336 initial_project_snapshot: {
337 let project_snapshot = Self::project_snapshot(project, cx);
338 cx.foreground_executor()
339 .spawn(async move { Some(project_snapshot.await) })
340 .shared()
341 },
342 request_token_usage: Vec::new(),
343 cumulative_token_usage: TokenUsage::default(),
344 exceeded_window_error: None,
345 feedback: None,
346 message_feedback: HashMap::default(),
347 last_auto_capture_at: None,
348 }
349 }
350
351 pub fn deserialize(
352 id: ThreadId,
353 serialized: SerializedThread,
354 project: Entity<Project>,
355 tools: Entity<ToolWorkingSet>,
356 prompt_builder: Arc<PromptBuilder>,
357 project_context: SharedProjectContext,
358 cx: &mut Context<Self>,
359 ) -> Self {
360 let next_message_id = MessageId(
361 serialized
362 .messages
363 .last()
364 .map(|message| message.id.0 + 1)
365 .unwrap_or(0),
366 );
367 let tool_use =
368 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
369
370 Self {
371 id,
372 updated_at: serialized.updated_at,
373 summary: Some(serialized.summary),
374 pending_summary: Task::ready(None),
375 detailed_summary_state: serialized.detailed_summary_state,
376 messages: serialized
377 .messages
378 .into_iter()
379 .map(|message| Message {
380 id: message.id,
381 role: message.role,
382 segments: message
383 .segments
384 .into_iter()
385 .map(|segment| match segment {
386 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
387 SerializedMessageSegment::Thinking { text } => {
388 MessageSegment::Thinking(text)
389 }
390 })
391 .collect(),
392 context: message.context,
393 })
394 .collect(),
395 next_message_id,
396 context: BTreeMap::default(),
397 context_by_message: HashMap::default(),
398 project_context,
399 checkpoints_by_message: HashMap::default(),
400 completion_count: 0,
401 pending_completions: Vec::new(),
402 last_restore_checkpoint: None,
403 pending_checkpoint: None,
404 project: project.clone(),
405 prompt_builder,
406 tools,
407 tool_use,
408 action_log: cx.new(|_| ActionLog::new(project)),
409 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
410 request_token_usage: serialized.request_token_usage,
411 cumulative_token_usage: serialized.cumulative_token_usage,
412 exceeded_window_error: None,
413 feedback: None,
414 message_feedback: HashMap::default(),
415 last_auto_capture_at: None,
416 }
417 }
418
419 pub fn id(&self) -> &ThreadId {
420 &self.id
421 }
422
423 pub fn is_empty(&self) -> bool {
424 self.messages.is_empty()
425 }
426
427 pub fn updated_at(&self) -> DateTime<Utc> {
428 self.updated_at
429 }
430
431 pub fn touch_updated_at(&mut self) {
432 self.updated_at = Utc::now();
433 }
434
435 pub fn summary(&self) -> Option<SharedString> {
436 self.summary.clone()
437 }
438
439 pub fn project_context(&self) -> SharedProjectContext {
440 self.project_context.clone()
441 }
442
443 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
444
445 pub fn summary_or_default(&self) -> SharedString {
446 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
447 }
448
449 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
450 let Some(current_summary) = &self.summary else {
451 // Don't allow setting summary until generated
452 return;
453 };
454
455 let mut new_summary = new_summary.into();
456
457 if new_summary.is_empty() {
458 new_summary = Self::DEFAULT_SUMMARY;
459 }
460
461 if current_summary != &new_summary {
462 self.summary = Some(new_summary);
463 cx.emit(ThreadEvent::SummaryChanged);
464 }
465 }
466
467 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
468 self.latest_detailed_summary()
469 .unwrap_or_else(|| self.text().into())
470 }
471
472 fn latest_detailed_summary(&self) -> Option<SharedString> {
473 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
474 Some(text.clone())
475 } else {
476 None
477 }
478 }
479
480 pub fn message(&self, id: MessageId) -> Option<&Message> {
481 self.messages.iter().find(|message| message.id == id)
482 }
483
484 pub fn messages(&self) -> impl Iterator<Item = &Message> {
485 self.messages.iter()
486 }
487
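    /// Returns whether the thread is still working: either a completion is still
    /// pending or at least one tool use has not finished running.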
488 pub fn is_generating(&self) -> bool {
489 !self.pending_completions.is_empty() || !self.all_tools_finished()
490 }
491
492 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
493 &self.tools
494 }
495
496 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
497 self.tool_use
498 .pending_tool_uses()
499 .into_iter()
500 .find(|tool_use| &tool_use.id == id)
501 }
502
503 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
504 self.tool_use
505 .pending_tool_uses()
506 .into_iter()
507 .filter(|tool_use| tool_use.status.needs_confirmation())
508 }
509
510 pub fn has_pending_tool_uses(&self) -> bool {
511 !self.tool_use.pending_tool_uses().is_empty()
512 }
513
514 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
515 self.checkpoints_by_message.get(&id).cloned()
516 }
517
518 pub fn restore_checkpoint(
519 &mut self,
520 checkpoint: ThreadCheckpoint,
521 cx: &mut Context<Self>,
522 ) -> Task<Result<()>> {
523 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
524 message_id: checkpoint.message_id,
525 });
526 cx.emit(ThreadEvent::CheckpointChanged);
527 cx.notify();
528
529 let git_store = self.project().read(cx).git_store().clone();
530 let restore = git_store.update(cx, |git_store, cx| {
531 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
532 });
533
534 cx.spawn(async move |this, cx| {
535 let result = restore.await;
536 this.update(cx, |this, cx| {
537 if let Err(err) = result.as_ref() {
538 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
539 message_id: checkpoint.message_id,
540 error: err.to_string(),
541 });
542 } else {
543 this.truncate(checkpoint.message_id, cx);
544 this.last_restore_checkpoint = None;
545 }
546 this.pending_checkpoint = None;
547 cx.emit(ThreadEvent::CheckpointChanged);
548 cx.notify();
549 })?;
550 result
551 })
552 }
553
554 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
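        // Keep the pending checkpoint only if the repository actually changed while the
        // turn ran: if the before/after checkpoints compare equal, the pending checkpoint
        // is deleted, and the temporary "final" checkpoint is always discarded afterwards.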
555 let pending_checkpoint = if self.is_generating() {
556 return;
557 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
558 checkpoint
559 } else {
560 return;
561 };
562
563 let git_store = self.project.read(cx).git_store().clone();
564 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
565 cx.spawn(async move |this, cx| match final_checkpoint.await {
566 Ok(final_checkpoint) => {
567 let equal = git_store
568 .update(cx, |store, cx| {
569 store.compare_checkpoints(
570 pending_checkpoint.git_checkpoint.clone(),
571 final_checkpoint.clone(),
572 cx,
573 )
574 })?
575 .await
576 .unwrap_or(false);
577
578 if equal {
579 git_store
580 .update(cx, |store, cx| {
581 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
582 })?
583 .detach();
584 } else {
585 this.update(cx, |this, cx| {
586 this.insert_checkpoint(pending_checkpoint, cx)
587 })?;
588 }
589
590 git_store
591 .update(cx, |store, cx| {
592 store.delete_checkpoint(final_checkpoint, cx)
593 })?
594 .detach();
595
596 Ok(())
597 }
598 Err(_) => this.update(cx, |this, cx| {
599 this.insert_checkpoint(pending_checkpoint, cx)
600 }),
601 })
602 .detach();
603 }
604
605 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
606 self.checkpoints_by_message
607 .insert(checkpoint.message_id, checkpoint);
608 cx.emit(ThreadEvent::CheckpointChanged);
609 cx.notify();
610 }
611
612 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
613 self.last_restore_checkpoint.as_ref()
614 }
615
616 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
617 let Some(message_ix) = self
618 .messages
619 .iter()
620 .rposition(|message| message.id == message_id)
621 else {
622 return;
623 };
624 for deleted_message in self.messages.drain(message_ix..) {
625 self.context_by_message.remove(&deleted_message.id);
626 self.checkpoints_by_message.remove(&deleted_message.id);
627 }
628 cx.notify();
629 }
630
631 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
632 self.context_by_message
633 .get(&id)
634 .into_iter()
635 .flat_map(|context| {
636 context
637 .iter()
638 .filter_map(|context_id| self.context.get(&context_id))
639 })
640 }
641
642 /// Returns whether all of the tool uses have finished running.
643 pub fn all_tools_finished(&self) -> bool {
644 // If the only pending tool uses left are the ones with errors, then
645 // that means that we've finished running all of the pending tools.
646 self.tool_use
647 .pending_tool_uses()
648 .iter()
649 .all(|tool_use| tool_use.status.is_error())
650 }
651
652 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
653 self.tool_use.tool_uses_for_message(id, cx)
654 }
655
656 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
657 self.tool_use.tool_results_for_message(id)
658 }
659
660 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
661 self.tool_use.tool_result(id)
662 }
663
664 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
665 Some(&self.tool_use.tool_result(id)?.content)
666 }
667
668 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
669 self.tool_use.tool_result_card(id).cloned()
670 }
671
672 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
673 self.tool_use.message_has_tool_results(message_id)
674 }
675
    /// Filters out contexts that have already been included in previous messages.
677 pub fn filter_new_context<'a>(
678 &self,
679 context: impl Iterator<Item = &'a AssistantContext>,
680 ) -> impl Iterator<Item = &'a AssistantContext> {
681 context.filter(|ctx| self.is_context_new(ctx))
682 }
683
684 fn is_context_new(&self, context: &AssistantContext) -> bool {
685 !self.context.contains_key(&context.id())
686 }
687
688 pub fn insert_user_message(
689 &mut self,
690 text: impl Into<String>,
691 context: Vec<AssistantContext>,
692 git_checkpoint: Option<GitStoreCheckpoint>,
693 cx: &mut Context<Self>,
694 ) -> MessageId {
695 let text = text.into();
696
697 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
698
699 let new_context: Vec<_> = context
700 .into_iter()
701 .filter(|ctx| self.is_context_new(ctx))
702 .collect();
703
704 if !new_context.is_empty() {
705 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
706 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
707 message.context = context_string;
708 }
709 }
710
711 self.action_log.update(cx, |log, cx| {
712 // Track all buffers added as context
713 for ctx in &new_context {
714 match ctx {
715 AssistantContext::File(file_ctx) => {
716 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
717 }
718 AssistantContext::Directory(dir_ctx) => {
719 for context_buffer in &dir_ctx.context_buffers {
720 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
721 }
722 }
723 AssistantContext::Symbol(symbol_ctx) => {
724 log.buffer_added_as_context(
725 symbol_ctx.context_symbol.buffer.clone(),
726 cx,
727 );
728 }
729 AssistantContext::Excerpt(excerpt_context) => {
730 log.buffer_added_as_context(
731 excerpt_context.context_buffer.buffer.clone(),
732 cx,
733 );
734 }
735 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
736 }
737 }
738 });
739 }
740
741 let context_ids = new_context
742 .iter()
743 .map(|context| context.id())
744 .collect::<Vec<_>>();
745 self.context.extend(
746 new_context
747 .into_iter()
748 .map(|context| (context.id(), context)),
749 );
750 self.context_by_message.insert(message_id, context_ids);
751
752 if let Some(git_checkpoint) = git_checkpoint {
753 self.pending_checkpoint = Some(ThreadCheckpoint {
754 message_id,
755 git_checkpoint,
756 });
757 }
758
759 self.auto_capture_telemetry(cx);
760
761 message_id
762 }
763
764 pub fn insert_message(
765 &mut self,
766 role: Role,
767 segments: Vec<MessageSegment>,
768 cx: &mut Context<Self>,
769 ) -> MessageId {
770 let id = self.next_message_id.post_inc();
771 self.messages.push(Message {
772 id,
773 role,
774 segments,
775 context: String::new(),
776 });
777 self.touch_updated_at();
778 cx.emit(ThreadEvent::MessageAdded(id));
779 id
780 }
781
782 pub fn edit_message(
783 &mut self,
784 id: MessageId,
785 new_role: Role,
786 new_segments: Vec<MessageSegment>,
787 cx: &mut Context<Self>,
788 ) -> bool {
789 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
790 return false;
791 };
792 message.role = new_role;
793 message.segments = new_segments;
794 self.touch_updated_at();
795 cx.emit(ThreadEvent::MessageEdited(id));
796 true
797 }
798
799 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
800 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
801 return false;
802 };
803 self.messages.remove(index);
804 self.context_by_message.remove(&id);
805 self.touch_updated_at();
806 cx.emit(ThreadEvent::MessageDeleted(id));
807 true
808 }
809
810 /// Returns the representation of this [`Thread`] in a textual form.
811 ///
812 /// This is the representation we use when attaching a thread as context to another thread.
813 pub fn text(&self) -> String {
814 let mut text = String::new();
815
816 for message in &self.messages {
817 text.push_str(match message.role {
818 language_model::Role::User => "User:",
819 language_model::Role::Assistant => "Assistant:",
820 language_model::Role::System => "System:",
821 });
822 text.push('\n');
823
824 for segment in &message.segments {
825 match segment {
826 MessageSegment::Text(content) => text.push_str(content),
827 MessageSegment::Thinking(content) => {
828 text.push_str(&format!("<think>{}</think>", content))
829 }
830 }
831 }
832 text.push('\n');
833 }
834
835 text
836 }
837
838 /// Serializes this thread into a format for storage or telemetry.
839 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
840 let initial_project_snapshot = self.initial_project_snapshot.clone();
841 cx.spawn(async move |this, cx| {
842 let initial_project_snapshot = initial_project_snapshot.await;
843 this.read_with(cx, |this, cx| SerializedThread {
844 version: SerializedThread::VERSION.to_string(),
845 summary: this.summary_or_default(),
846 updated_at: this.updated_at(),
847 messages: this
848 .messages()
849 .map(|message| SerializedMessage {
850 id: message.id,
851 role: message.role,
852 segments: message
853 .segments
854 .iter()
855 .map(|segment| match segment {
856 MessageSegment::Text(text) => {
857 SerializedMessageSegment::Text { text: text.clone() }
858 }
859 MessageSegment::Thinking(text) => {
860 SerializedMessageSegment::Thinking { text: text.clone() }
861 }
862 })
863 .collect(),
864 tool_uses: this
865 .tool_uses_for_message(message.id, cx)
866 .into_iter()
867 .map(|tool_use| SerializedToolUse {
868 id: tool_use.id,
869 name: tool_use.name,
870 input: tool_use.input,
871 })
872 .collect(),
873 tool_results: this
874 .tool_results_for_message(message.id)
875 .into_iter()
876 .map(|tool_result| SerializedToolResult {
877 tool_use_id: tool_result.tool_use_id.clone(),
878 is_error: tool_result.is_error,
879 content: tool_result.content.clone(),
880 })
881 .collect(),
882 context: message.context.clone(),
883 })
884 .collect(),
885 initial_project_snapshot,
886 cumulative_token_usage: this.cumulative_token_usage,
887 request_token_usage: this.request_token_usage.clone(),
888 detailed_summary_state: this.detailed_summary_state.clone(),
889 exceeded_window_error: this.exceeded_window_error.clone(),
890 })
891 })
892 }
893
894 pub fn send_to_model(
895 &mut self,
896 model: Arc<dyn LanguageModel>,
897 request_kind: RequestKind,
898 cx: &mut Context<Self>,
899 ) {
900 let mut request = self.to_completion_request(request_kind, cx);
901 if model.supports_tools() {
902 request.tools = {
903 let mut tools = Vec::new();
904 tools.extend(
905 self.tools()
906 .read(cx)
907 .enabled_tools(cx)
908 .into_iter()
909 .filter_map(|tool| {
                            // Skip tools whose input schema can't be represented in this model's tool input format.
911 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
912 Some(LanguageModelRequestTool {
913 name: tool.name(),
914 description: tool.description(),
915 input_schema,
916 })
917 }),
918 );
919
920 tools
921 };
922 }
923
924 self.stream_completion(request, model, cx);
925 }
926
927 pub fn used_tools_since_last_user_message(&self) -> bool {
928 for message in self.messages.iter().rev() {
929 if self.tool_use.message_has_tool_results(message.id) {
930 return true;
931 } else if message.role == Role::User {
932 return false;
933 }
934 }
935
936 false
937 }
938
939 pub fn to_completion_request(
940 &self,
941 request_kind: RequestKind,
942 cx: &App,
943 ) -> LanguageModelRequest {
944 let mut request = LanguageModelRequest {
945 messages: vec![],
946 tools: Vec::new(),
947 stop: Vec::new(),
948 temperature: None,
949 };
950
951 if let Some(project_context) = self.project_context.borrow().as_ref() {
952 if let Some(system_prompt) = self
953 .prompt_builder
954 .generate_assistant_system_prompt(project_context)
955 .context("failed to generate assistant system prompt")
956 .log_err()
957 {
958 request.messages.push(LanguageModelRequestMessage {
959 role: Role::System,
960 content: vec![MessageContent::Text(system_prompt)],
961 cache: true,
962 });
963 }
964 } else {
965 log::error!("project_context not set.")
966 }
967
968 for message in &self.messages {
969 let mut request_message = LanguageModelRequestMessage {
970 role: message.role,
971 content: Vec::new(),
972 cache: false,
973 };
974
975 match request_kind {
976 RequestKind::Chat => {
977 self.tool_use
978 .attach_tool_results(message.id, &mut request_message);
979 }
980 RequestKind::Summarize => {
981 // We don't care about tool use during summarization.
982 if self.tool_use.message_has_tool_results(message.id) {
983 continue;
984 }
985 }
986 }
987
988 if !message.segments.is_empty() {
989 request_message
990 .content
991 .push(MessageContent::Text(message.to_string()));
992 }
993
994 match request_kind {
995 RequestKind::Chat => {
996 self.tool_use
997 .attach_tool_uses(message.id, &mut request_message);
998 }
999 RequestKind::Summarize => {
1000 // We don't care about tool use during summarization.
1001 }
1002 };
1003
1004 request.messages.push(request_message);
1005 }
1006
1007 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
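        // Mark the final message as a cache breakpoint so providers that support prompt
        // caching can reuse the shared prefix of the conversation on the next request.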
1008 if let Some(last) = request.messages.last_mut() {
1009 last.cache = true;
1010 }
1011
1012 self.attached_tracked_files_state(&mut request.messages, cx);
1013
1014 request
1015 }
1016
1017 fn attached_tracked_files_state(
1018 &self,
1019 messages: &mut Vec<LanguageModelRequestMessage>,
1020 cx: &App,
1021 ) {
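        // Appends a synthetic user message listing buffers that changed since the model
        // last read them, plus a reminder to re-check project diagnostics after edits.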
1022 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1023
1024 let mut stale_message = String::new();
1025
1026 let action_log = self.action_log.read(cx);
1027
1028 for stale_file in action_log.stale_buffers(cx) {
1029 let Some(file) = stale_file.read(cx).file() else {
1030 continue;
1031 };
1032
1033 if stale_message.is_empty() {
1034 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1035 }
1036
1037 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1038 }
1039
1040 let mut content = Vec::with_capacity(2);
1041
1042 if !stale_message.is_empty() {
1043 content.push(stale_message.into());
1044 }
1045
1046 if action_log.has_edited_files_since_project_diagnostics_check() {
1047 content.push(
1048 "\n\nWhen you're done making changes, make sure to check project diagnostics \
1049 and fix all errors AND warnings you introduced! \
1050 DO NOT mention you're going to do this until you're done."
1051 .into(),
1052 );
1053 }
1054
1055 if !content.is_empty() {
1056 let context_message = LanguageModelRequestMessage {
1057 role: Role::User,
1058 content,
1059 cache: false,
1060 };
1061
1062 messages.push(context_message);
1063 }
1064 }
1065
1066 pub fn stream_completion(
1067 &mut self,
1068 request: LanguageModelRequest,
1069 model: Arc<dyn LanguageModel>,
1070 cx: &mut Context<Self>,
1071 ) {
1072 let pending_completion_id = post_inc(&mut self.completion_count);
1073 let task = cx.spawn(async move |thread, cx| {
1074 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1075 let initial_token_usage =
1076 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1077 let stream_completion = async {
1078 let (mut events, usage) = stream_completion_future.await?;
1079 let mut stop_reason = StopReason::EndTurn;
1080 let mut current_token_usage = TokenUsage::default();
1081
1082 if let Some(usage) = usage {
1083 let limit = match usage.limit {
1084 UsageLimit::Limited(limit) => limit.to_string(),
1085 UsageLimit::Unlimited => "unlimited".to_string(),
1086 };
1087 log::info!("model request usage: {} / {}", usage.amount, limit);
1088 }
1089
1090 while let Some(event) = events.next().await {
1091 let event = event?;
1092
1093 thread.update(cx, |thread, cx| {
1094 match event {
1095 LanguageModelCompletionEvent::StartMessage { .. } => {
1096 thread.insert_message(
1097 Role::Assistant,
1098 vec![MessageSegment::Text(String::new())],
1099 cx,
1100 );
1101 }
1102 LanguageModelCompletionEvent::Stop(reason) => {
1103 stop_reason = reason;
1104 }
1105 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1106 thread.update_token_usage_at_last_message(token_usage);
1107 thread.cumulative_token_usage = thread.cumulative_token_usage
1108 + token_usage
1109 - current_token_usage;
1110 current_token_usage = token_usage;
1111 }
1112 LanguageModelCompletionEvent::Text(chunk) => {
1113 if let Some(last_message) = thread.messages.last_mut() {
1114 if last_message.role == Role::Assistant {
1115 last_message.push_text(&chunk);
1116 cx.emit(ThreadEvent::StreamedAssistantText(
1117 last_message.id,
1118 chunk,
1119 ));
1120 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1122 // of a new Assistant response.
1123 //
1124 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1125 // will result in duplicating the text of the chunk in the rendered Markdown.
1126 thread.insert_message(
1127 Role::Assistant,
1128 vec![MessageSegment::Text(chunk.to_string())],
1129 cx,
1130 );
1131 };
1132 }
1133 }
1134 LanguageModelCompletionEvent::Thinking(chunk) => {
1135 if let Some(last_message) = thread.messages.last_mut() {
1136 if last_message.role == Role::Assistant {
1137 last_message.push_thinking(&chunk);
1138 cx.emit(ThreadEvent::StreamedAssistantThinking(
1139 last_message.id,
1140 chunk,
1141 ));
1142 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1144 // of a new Assistant response.
1145 //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
1147 // will result in duplicating the text of the chunk in the rendered Markdown.
1148 thread.insert_message(
1149 Role::Assistant,
1150 vec![MessageSegment::Thinking(chunk.to_string())],
1151 cx,
1152 );
1153 };
1154 }
1155 }
1156 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1157 let last_assistant_message_id = thread
1158 .messages
1159 .iter_mut()
1160 .rfind(|message| message.role == Role::Assistant)
1161 .map(|message| message.id)
1162 .unwrap_or_else(|| {
1163 thread.insert_message(Role::Assistant, vec![], cx)
1164 });
1165
1166 thread.tool_use.request_tool_use(
1167 last_assistant_message_id,
1168 tool_use,
1169 cx,
1170 );
1171 }
1172 }
1173
1174 thread.touch_updated_at();
1175 cx.emit(ThreadEvent::StreamedCompletion);
1176 cx.notify();
1177
1178 thread.auto_capture_telemetry(cx);
1179 })?;
1180
1181 smol::future::yield_now().await;
1182 }
1183
1184 thread.update(cx, |thread, cx| {
1185 thread
1186 .pending_completions
1187 .retain(|completion| completion.id != pending_completion_id);
1188
1189 if thread.summary.is_none() && thread.messages.len() >= 2 {
1190 thread.summarize(cx);
1191 }
1192 })?;
1193
1194 anyhow::Ok(stop_reason)
1195 };
1196
1197 let result = stream_completion.await;
1198
1199 thread
1200 .update(cx, |thread, cx| {
1201 thread.finalize_pending_checkpoint(cx);
1202 match result.as_ref() {
1203 Ok(stop_reason) => match stop_reason {
1204 StopReason::ToolUse => {
1205 let tool_uses = thread.use_pending_tools(cx);
1206 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1207 }
1208 StopReason::EndTurn => {}
1209 StopReason::MaxTokens => {}
1210 },
1211 Err(error) => {
1212 if error.is::<PaymentRequiredError>() {
1213 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1214 } else if error.is::<MaxMonthlySpendReachedError>() {
1215 cx.emit(ThreadEvent::ShowError(
1216 ThreadError::MaxMonthlySpendReached,
1217 ));
1218 } else if let Some(error) =
1219 error.downcast_ref::<ModelRequestLimitReachedError>()
1220 {
1221 cx.emit(ThreadEvent::ShowError(
1222 ThreadError::ModelRequestLimitReached { plan: error.plan },
1223 ));
1224 } else if let Some(known_error) =
1225 error.downcast_ref::<LanguageModelKnownError>()
1226 {
1227 match known_error {
1228 LanguageModelKnownError::ContextWindowLimitExceeded {
1229 tokens,
1230 } => {
1231 thread.exceeded_window_error = Some(ExceededWindowError {
1232 model_id: model.id(),
1233 token_count: *tokens,
1234 });
1235 cx.notify();
1236 }
1237 }
1238 } else {
1239 let error_message = error
1240 .chain()
1241 .map(|err| err.to_string())
1242 .collect::<Vec<_>>()
1243 .join("\n");
1244 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1245 header: "Error interacting with language model".into(),
1246 message: SharedString::from(error_message.clone()),
1247 }));
1248 }
1249
1250 thread.cancel_last_completion(cx);
1251 }
1252 }
1253 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1254
1255 thread.auto_capture_telemetry(cx);
1256
1257 if let Ok(initial_usage) = initial_token_usage {
1258 let usage = thread.cumulative_token_usage - initial_usage;
1259
1260 telemetry::event!(
1261 "Assistant Thread Completion",
1262 thread_id = thread.id().to_string(),
1263 model = model.telemetry_id(),
1264 model_provider = model.provider_id().to_string(),
1265 input_tokens = usage.input_tokens,
1266 output_tokens = usage.output_tokens,
1267 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1268 cache_read_input_tokens = usage.cache_read_input_tokens,
1269 );
1270 }
1271 })
1272 .ok();
1273 });
1274
1275 self.pending_completions.push(PendingCompletion {
1276 id: pending_completion_id,
1277 _task: task,
1278 });
1279 }
1280
1281 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1282 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1283 return;
1284 };
1285
1286 if !model.provider.is_authenticated(cx) {
1287 return;
1288 }
1289
1290 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1291 request.messages.push(LanguageModelRequestMessage {
1292 role: Role::User,
1293 content: vec![
1294 "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1295 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1296 If the conversation is about a specific subject, include it in the title. \
1297 Be descriptive. DO NOT speak in the first person."
1298 .into(),
1299 ],
1300 cache: false,
1301 });
1302
1303 self.pending_summary = cx.spawn(async move |this, cx| {
1304 async move {
1305 let stream = model.model.stream_completion_text(request, &cx);
1306 let mut messages = stream.await?;
1307
1308 let mut new_summary = String::new();
1309 while let Some(message) = messages.stream.next().await {
1310 let text = message?;
1311 let mut lines = text.lines();
1312 new_summary.extend(lines.next());
1313
1314 // Stop if the LLM generated multiple lines.
1315 if lines.next().is_some() {
1316 break;
1317 }
1318 }
1319
1320 this.update(cx, |this, cx| {
1321 if !new_summary.is_empty() {
1322 this.summary = Some(new_summary.into());
1323 }
1324
1325 cx.emit(ThreadEvent::SummaryGenerated);
1326 })?;
1327
1328 anyhow::Ok(())
1329 }
1330 .log_err()
1331 .await
1332 });
1333 }
1334
1335 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1336 let last_message_id = self.messages.last().map(|message| message.id)?;
1337
1338 match &self.detailed_summary_state {
1339 DetailedSummaryState::Generating { message_id, .. }
1340 | DetailedSummaryState::Generated { message_id, .. }
1341 if *message_id == last_message_id =>
1342 {
1343 // Already up-to-date
1344 return None;
1345 }
1346 _ => {}
1347 }
1348
1349 let ConfiguredModel { model, provider } =
1350 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1351
1352 if !provider.is_authenticated(cx) {
1353 return None;
1354 }
1355
1356 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1357
1358 request.messages.push(LanguageModelRequestMessage {
1359 role: Role::User,
1360 content: vec![
1361 "Generate a detailed summary of this conversation. Include:\n\
1362 1. A brief overview of what was discussed\n\
1363 2. Key facts or information discovered\n\
1364 3. Outcomes or conclusions reached\n\
1365 4. Any action items or next steps if any\n\
1366 Format it in Markdown with headings and bullet points."
1367 .into(),
1368 ],
1369 cache: false,
1370 });
1371
1372 let task = cx.spawn(async move |thread, cx| {
1373 let stream = model.stream_completion_text(request, &cx);
1374 let Some(mut messages) = stream.await.log_err() else {
1375 thread
1376 .update(cx, |this, _cx| {
1377 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1378 })
1379 .log_err();
1380
1381 return;
1382 };
1383
1384 let mut new_detailed_summary = String::new();
1385
1386 while let Some(chunk) = messages.stream.next().await {
1387 if let Some(chunk) = chunk.log_err() {
1388 new_detailed_summary.push_str(&chunk);
1389 }
1390 }
1391
1392 thread
1393 .update(cx, |this, _cx| {
1394 this.detailed_summary_state = DetailedSummaryState::Generated {
1395 text: new_detailed_summary.into(),
1396 message_id: last_message_id,
1397 };
1398 })
1399 .log_err();
1400 });
1401
1402 self.detailed_summary_state = DetailedSummaryState::Generating {
1403 message_id: last_message_id,
1404 };
1405
1406 Some(task)
1407 }
1408
1409 pub fn is_generating_detailed_summary(&self) -> bool {
1410 matches!(
1411 self.detailed_summary_state,
1412 DetailedSummaryState::Generating { .. }
1413 )
1414 }
1415
1416 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1417 self.auto_capture_telemetry(cx);
1418 let request = self.to_completion_request(RequestKind::Chat, cx);
1419 let messages = Arc::new(request.messages);
1420 let pending_tool_uses = self
1421 .tool_use
1422 .pending_tool_uses()
1423 .into_iter()
1424 .filter(|tool_use| tool_use.status.is_idle())
1425 .cloned()
1426 .collect::<Vec<_>>();
1427
1428 for tool_use in pending_tool_uses.iter() {
1429 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1430 if tool.needs_confirmation(&tool_use.input, cx)
1431 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1432 {
1433 self.tool_use.confirm_tool_use(
1434 tool_use.id.clone(),
1435 tool_use.ui_text.clone(),
1436 tool_use.input.clone(),
1437 messages.clone(),
1438 tool,
1439 );
1440 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1441 } else {
1442 self.run_tool(
1443 tool_use.id.clone(),
1444 tool_use.ui_text.clone(),
1445 tool_use.input.clone(),
1446 &messages,
1447 tool,
1448 cx,
1449 );
1450 }
1451 }
1452 }
1453
1454 pending_tool_uses
1455 }
1456
1457 pub fn run_tool(
1458 &mut self,
1459 tool_use_id: LanguageModelToolUseId,
1460 ui_text: impl Into<SharedString>,
1461 input: serde_json::Value,
1462 messages: &[LanguageModelRequestMessage],
1463 tool: Arc<dyn Tool>,
1464 cx: &mut Context<Thread>,
1465 ) {
1466 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1467 self.tool_use
1468 .run_pending_tool(tool_use_id, ui_text.into(), task);
1469 }
1470
1471 fn spawn_tool_use(
1472 &mut self,
1473 tool_use_id: LanguageModelToolUseId,
1474 messages: &[LanguageModelRequestMessage],
1475 input: serde_json::Value,
1476 tool: Arc<dyn Tool>,
1477 cx: &mut Context<Thread>,
1478 ) -> Task<()> {
1479 let tool_name: Arc<str> = tool.name().into();
1480
1481 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1482 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1483 } else {
1484 tool.run(
1485 input,
1486 messages,
1487 self.project.clone(),
1488 self.action_log.clone(),
1489 cx,
1490 )
1491 };
1492
1493 // Store the card separately if it exists
1494 if let Some(card) = tool_result.card.clone() {
1495 self.tool_use
1496 .insert_tool_result_card(tool_use_id.clone(), card);
1497 }
1498
1499 cx.spawn({
1500 async move |thread: WeakEntity<Thread>, cx| {
1501 let output = tool_result.output.await;
1502
1503 thread
1504 .update(cx, |thread, cx| {
1505 let pending_tool_use = thread.tool_use.insert_tool_output(
1506 tool_use_id.clone(),
1507 tool_name,
1508 output,
1509 cx,
1510 );
1511 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1512 })
1513 .ok();
1514 }
1515 })
1516 }
1517
1518 fn tool_finished(
1519 &mut self,
1520 tool_use_id: LanguageModelToolUseId,
1521 pending_tool_use: Option<PendingToolUse>,
1522 canceled: bool,
1523 cx: &mut Context<Self>,
1524 ) {
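        // Once every pending tool has produced a result (or failed), fold the results
        // into a new user message and, unless this turn was canceled, immediately send a
        // follow-up request to the model.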
1525 if self.all_tools_finished() {
1526 let model_registry = LanguageModelRegistry::read_global(cx);
1527 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1528 self.attach_tool_results(cx);
1529 if !canceled {
1530 self.send_to_model(model, RequestKind::Chat, cx);
1531 }
1532 }
1533 }
1534
1535 cx.emit(ThreadEvent::ToolFinished {
1536 tool_use_id,
1537 pending_tool_use,
1538 });
1539 }
1540
1541 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1542 // Insert a user message to contain the tool results.
1543 self.insert_user_message(
1544 // TODO: Sending up a user message without any content results in the model sending back
1545 // responses that also don't have any content. We currently don't handle this case well,
1546 // so for now we provide some text to keep the model on track.
1547 "Here are the tool results.",
1548 Vec::new(),
1549 None,
1550 cx,
1551 );
1552 }
1553
    /// Cancels the last pending completion, if there is one; otherwise cancels any
    /// pending tool uses.
    ///
    /// Returns whether anything was canceled.
1557 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1558 let canceled = if self.pending_completions.pop().is_some() {
1559 true
1560 } else {
1561 let mut canceled = false;
1562 for pending_tool_use in self.tool_use.cancel_pending() {
1563 canceled = true;
1564 self.tool_finished(
1565 pending_tool_use.id.clone(),
1566 Some(pending_tool_use),
1567 true,
1568 cx,
1569 );
1570 }
1571 canceled
1572 };
1573 self.finalize_pending_checkpoint(cx);
1574 canceled
1575 }
1576
1577 pub fn feedback(&self) -> Option<ThreadFeedback> {
1578 self.feedback
1579 }
1580
1581 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1582 self.message_feedback.get(&message_id).copied()
1583 }
1584
1585 pub fn report_message_feedback(
1586 &mut self,
1587 message_id: MessageId,
1588 feedback: ThreadFeedback,
1589 cx: &mut Context<Self>,
1590 ) -> Task<Result<()>> {
1591 if self.message_feedback.get(&message_id) == Some(&feedback) {
1592 return Task::ready(Ok(()));
1593 }
1594
1595 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1596 let serialized_thread = self.serialize(cx);
1597 let thread_id = self.id().clone();
1598 let client = self.project.read(cx).client();
1599
1600 let enabled_tool_names: Vec<String> = self
1601 .tools()
1602 .read(cx)
1603 .enabled_tools(cx)
1604 .iter()
1605 .map(|tool| tool.name().to_string())
1606 .collect();
1607
1608 self.message_feedback.insert(message_id, feedback);
1609
1610 cx.notify();
1611
1612 let message_content = self
1613 .message(message_id)
1614 .map(|msg| msg.to_string())
1615 .unwrap_or_default();
1616
1617 cx.background_spawn(async move {
1618 let final_project_snapshot = final_project_snapshot.await;
1619 let serialized_thread = serialized_thread.await?;
1620 let thread_data =
1621 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1622
1623 let rating = match feedback {
1624 ThreadFeedback::Positive => "positive",
1625 ThreadFeedback::Negative => "negative",
1626 };
1627 telemetry::event!(
1628 "Assistant Thread Rated",
1629 rating,
1630 thread_id,
1631 enabled_tool_names,
1632 message_id = message_id.0,
1633 message_content,
1634 thread_data,
1635 final_project_snapshot
1636 );
1637 client.telemetry().flush_events();
1638
1639 Ok(())
1640 })
1641 }
1642
1643 pub fn report_feedback(
1644 &mut self,
1645 feedback: ThreadFeedback,
1646 cx: &mut Context<Self>,
1647 ) -> Task<Result<()>> {
1648 let last_assistant_message_id = self
1649 .messages
1650 .iter()
1651 .rev()
1652 .find(|msg| msg.role == Role::Assistant)
1653 .map(|msg| msg.id);
1654
1655 if let Some(message_id) = last_assistant_message_id {
1656 self.report_message_feedback(message_id, feedback, cx)
1657 } else {
1658 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1659 let serialized_thread = self.serialize(cx);
1660 let thread_id = self.id().clone();
1661 let client = self.project.read(cx).client();
1662 self.feedback = Some(feedback);
1663 cx.notify();
1664
1665 cx.background_spawn(async move {
1666 let final_project_snapshot = final_project_snapshot.await;
1667 let serialized_thread = serialized_thread.await?;
1668 let thread_data = serde_json::to_value(serialized_thread)
1669 .unwrap_or_else(|_| serde_json::Value::Null);
1670
1671 let rating = match feedback {
1672 ThreadFeedback::Positive => "positive",
1673 ThreadFeedback::Negative => "negative",
1674 };
1675 telemetry::event!(
1676 "Assistant Thread Rated",
1677 rating,
1678 thread_id,
1679 thread_data,
1680 final_project_snapshot
1681 );
1682 client.telemetry().flush_events();
1683
1684 Ok(())
1685 })
1686 }
1687 }
1688
    /// Creates a snapshot of the current project state, including git information and unsaved buffers.
1690 fn project_snapshot(
1691 project: Entity<Project>,
1692 cx: &mut Context<Self>,
1693 ) -> Task<Arc<ProjectSnapshot>> {
1694 let git_store = project.read(cx).git_store().clone();
1695 let worktree_snapshots: Vec<_> = project
1696 .read(cx)
1697 .visible_worktrees(cx)
1698 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1699 .collect();
1700
1701 cx.spawn(async move |_, cx| {
1702 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1703
1704 let mut unsaved_buffers = Vec::new();
1705 cx.update(|app_cx| {
1706 let buffer_store = project.read(app_cx).buffer_store();
1707 for buffer_handle in buffer_store.read(app_cx).buffers() {
1708 let buffer = buffer_handle.read(app_cx);
1709 if buffer.is_dirty() {
1710 if let Some(file) = buffer.file() {
1711 let path = file.path().to_string_lossy().to_string();
1712 unsaved_buffers.push(path);
1713 }
1714 }
1715 }
1716 })
1717 .ok();
1718
1719 Arc::new(ProjectSnapshot {
1720 worktree_snapshots,
1721 unsaved_buffer_paths: unsaved_buffers,
1722 timestamp: Utc::now(),
1723 })
1724 })
1725 }
1726
1727 fn worktree_snapshot(
1728 worktree: Entity<project::Worktree>,
1729 git_store: Entity<GitStore>,
1730 cx: &App,
1731 ) -> Task<WorktreeSnapshot> {
1732 cx.spawn(async move |cx| {
1733 // Get worktree path and snapshot
1734 let worktree_info = cx.update(|app_cx| {
1735 let worktree = worktree.read(app_cx);
1736 let path = worktree.abs_path().to_string_lossy().to_string();
1737 let snapshot = worktree.snapshot();
1738 (path, snapshot)
1739 });
1740
1741 let Ok((worktree_path, _snapshot)) = worktree_info else {
1742 return WorktreeSnapshot {
1743 worktree_path: String::new(),
1744 git_state: None,
1745 };
1746 };
1747
1748 let git_state = git_store
1749 .update(cx, |git_store, cx| {
1750 git_store
1751 .repositories()
1752 .values()
1753 .find(|repo| {
1754 repo.read(cx)
1755 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1756 .is_some()
1757 })
1758 .cloned()
1759 })
1760 .ok()
1761 .flatten()
1762 .map(|repo| {
1763 repo.update(cx, |repo, _| {
1764 let current_branch =
1765 repo.branch.as_ref().map(|branch| branch.name.to_string());
1766 repo.send_job(None, |state, _| async move {
1767 let RepositoryState::Local { backend, .. } = state else {
1768 return GitState {
1769 remote_url: None,
1770 head_sha: None,
1771 current_branch,
1772 diff: None,
1773 };
1774 };
1775
1776 let remote_url = backend.remote_url("origin");
1777 let head_sha = backend.head_sha();
1778 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1779
1780 GitState {
1781 remote_url,
1782 head_sha,
1783 current_branch,
1784 diff,
1785 }
1786 })
1787 })
1788 });
1789
1790 let git_state = match git_state {
1791 Some(git_state) => match git_state.ok() {
1792 Some(git_state) => git_state.await.ok(),
1793 None => None,
1794 },
1795 None => None,
1796 };
1797
1798 WorktreeSnapshot {
1799 worktree_path,
1800 git_state,
1801 }
1802 })
1803 }
1804
1805 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1806 let mut markdown = Vec::new();
1807
1808 if let Some(summary) = self.summary() {
1809 writeln!(markdown, "# {summary}\n")?;
1810 };
1811
1812 for message in self.messages() {
1813 writeln!(
1814 markdown,
1815 "## {role}\n",
1816 role = match message.role {
1817 Role::User => "User",
1818 Role::Assistant => "Assistant",
1819 Role::System => "System",
1820 }
1821 )?;
1822
1823 if !message.context.is_empty() {
1824 writeln!(markdown, "{}", message.context)?;
1825 }
1826
1827 for segment in &message.segments {
1828 match segment {
1829 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1830 MessageSegment::Thinking(text) => {
1831 writeln!(markdown, "<think>{}</think>\n", text)?
1832 }
1833 }
1834 }
1835
1836 for tool_use in self.tool_uses_for_message(message.id, cx) {
1837 writeln!(
1838 markdown,
1839 "**Use Tool: {} ({})**",
1840 tool_use.name, tool_use.id
1841 )?;
1842 writeln!(markdown, "```json")?;
1843 writeln!(
1844 markdown,
1845 "{}",
1846 serde_json::to_string_pretty(&tool_use.input)?
1847 )?;
1848 writeln!(markdown, "```")?;
1849 }
1850
1851 for tool_result in self.tool_results_for_message(message.id) {
1852 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1853 if tool_result.is_error {
1854 write!(markdown, " (Error)")?;
1855 }
1856
1857 writeln!(markdown, "**\n")?;
1858 writeln!(markdown, "{}", tool_result.content)?;
1859 }
1860 }
1861
1862 Ok(String::from_utf8_lossy(&markdown).to_string())
1863 }
1864
1865 pub fn keep_edits_in_range(
1866 &mut self,
1867 buffer: Entity<language::Buffer>,
1868 buffer_range: Range<language::Anchor>,
1869 cx: &mut Context<Self>,
1870 ) {
1871 self.action_log.update(cx, |action_log, cx| {
1872 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1873 });
1874 }
1875
1876 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1877 self.action_log
1878 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1879 }
1880
1881 pub fn reject_edits_in_ranges(
1882 &mut self,
1883 buffer: Entity<language::Buffer>,
1884 buffer_ranges: Vec<Range<language::Anchor>>,
1885 cx: &mut Context<Self>,
1886 ) -> Task<Result<()>> {
1887 self.action_log.update(cx, |action_log, cx| {
1888 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1889 })
1890 }
1891
1892 pub fn action_log(&self) -> &Entity<ActionLog> {
1893 &self.action_log
1894 }
1895
1896 pub fn project(&self) -> &Entity<Project> {
1897 &self.project
1898 }
1899
1900 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1901 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1902 return;
1903 }
1904
1905 let now = Instant::now();
1906 if let Some(last) = self.last_auto_capture_at {
1907 if now.duration_since(last).as_secs() < 10 {
1908 return;
1909 }
1910 }
1911
1912 self.last_auto_capture_at = Some(now);
1913
1914 let thread_id = self.id().clone();
1915 let github_login = self
1916 .project
1917 .read(cx)
1918 .user_store()
1919 .read(cx)
1920 .current_user()
1921 .map(|user| user.github_login.clone());
1922 let client = self.project.read(cx).client().clone();
1923 let serialize_task = self.serialize(cx);
1924
1925 cx.background_executor()
1926 .spawn(async move {
1927 if let Ok(serialized_thread) = serialize_task.await {
1928 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1929 telemetry::event!(
1930 "Agent Thread Auto-Captured",
1931 thread_id = thread_id.to_string(),
1932 thread_data = thread_data,
1933 auto_capture_reason = "tracked_user",
1934 github_login = github_login
1935 );
1936
1937 client.telemetry().flush_events();
1938 }
1939 }
1940 })
1941 .detach();
1942 }
1943
1944 pub fn cumulative_token_usage(&self) -> TokenUsage {
1945 self.cumulative_token_usage
1946 }
1947
1948 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
1949 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
1950 return TotalTokenUsage::default();
1951 };
1952
1953 let max = model.model.max_token_count();
1954
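        // `request_token_usage[i]` holds the usage reported after message `i`, so the
        // usage accumulated up to (but not including) `message_id` lives at `index - 1`.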
1955 let index = self
1956 .messages
1957 .iter()
1958 .position(|msg| msg.id == message_id)
1959 .unwrap_or(0);
1960
1961 if index == 0 {
1962 return TotalTokenUsage { total: 0, max };
1963 }
1964
1965 let token_usage = &self
1966 .request_token_usage
1967 .get(index - 1)
1968 .cloned()
1969 .unwrap_or_default();
1970
1971 TotalTokenUsage {
1972 total: token_usage.total_tokens() as usize,
1973 max,
1974 }
1975 }
1976
1977 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1978 let model_registry = LanguageModelRegistry::read_global(cx);
1979 let Some(model) = model_registry.default_model() else {
1980 return TotalTokenUsage::default();
1981 };
1982
1983 let max = model.model.max_token_count();
1984
1985 if let Some(exceeded_error) = &self.exceeded_window_error {
1986 if model.model.id() == exceeded_error.model_id {
1987 return TotalTokenUsage {
1988 total: exceeded_error.token_count,
1989 max,
1990 };
1991 }
1992 }
1993
1994 let total = self
1995 .token_usage_at_last_message()
1996 .unwrap_or_default()
1997 .total_tokens() as usize;
1998
1999 TotalTokenUsage { total, max }
2000 }
2001
2002 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2003 self.request_token_usage
2004 .get(self.messages.len().saturating_sub(1))
2005 .or_else(|| self.request_token_usage.last())
2006 .cloned()
2007 }
2008
2009 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
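        // Keep one usage entry per message, backfilling any missing entries with the most
        // recent known usage before overwriting the entry for the last message.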
2010 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2011 self.request_token_usage
2012 .resize(self.messages.len(), placeholder);
2013
2014 if let Some(last) = self.request_token_usage.last_mut() {
2015 *last = token_usage;
2016 }
2017 }
2018
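    /// Denies a pending tool use, recording a permission-denied error as the tool's output.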
2019 pub fn deny_tool_use(
2020 &mut self,
2021 tool_use_id: LanguageModelToolUseId,
2022 tool_name: Arc<str>,
2023 cx: &mut Context<Self>,
2024 ) {
        let err = Err(anyhow!("Permission to run tool action denied by user"));

        self.tool_use
            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
        self.tool_finished(tool_use_id, None, true, cx);
2032 }
2033}
2034
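/// Errors surfaced to the user while running a thread.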
2035#[derive(Debug, Clone, Error)]
2036pub enum ThreadError {
2037 #[error("Payment required")]
2038 PaymentRequired,
2039 #[error("Max monthly spend reached")]
2040 MaxMonthlySpendReached,
2041 #[error("Model request limit reached")]
2042 ModelRequestLimitReached { plan: Plan },
2043 #[error("Message {header}: {message}")]
2044 Message {
2045 header: SharedString,
2046 message: SharedString,
2047 },
2048}
2049
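/// Events emitted by a [`Thread`] as completions stream and tools run.
///
/// A rough sketch of how an observer entity might react to these events,
/// assuming a gpui `Context` (`cx`) and an existing `thread` entity:
///
/// ```ignore
/// cx.subscribe(&thread, |_this, _thread, event: &ThreadEvent, _cx| match event {
///     ThreadEvent::StreamedAssistantText(message_id, chunk) => {
///         // Append the streamed chunk to the message with `message_id`.
///     }
///     ThreadEvent::ShowError(error) => {
///         // Surface the error to the user.
///     }
///     _ => {}
/// })
/// .detach();
/// ```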
2050#[derive(Debug, Clone)]
2051pub enum ThreadEvent {
2052 ShowError(ThreadError),
2053 StreamedCompletion,
2054 StreamedAssistantText(MessageId, String),
2055 StreamedAssistantThinking(MessageId, String),
2056 Stopped(Result<StopReason, Arc<anyhow::Error>>),
2057 MessageAdded(MessageId),
2058 MessageEdited(MessageId),
2059 MessageDeleted(MessageId),
2060 SummaryGenerated,
2061 SummaryChanged,
2062 UsePendingTools {
2063 tool_uses: Vec<PendingToolUse>,
2064 },
2065 ToolFinished {
2066 #[allow(unused)]
2067 tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that this result corresponds to, if any.
2069 pending_tool_use: Option<PendingToolUse>,
2070 },
2071 CheckpointChanged,
2072 ToolConfirmationNeeded,
2073}
2074
2075impl EventEmitter<ThreadEvent> for Thread {}
2076
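/// An in-flight completion request; dropping the task cancels it.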
2077struct PendingCompletion {
2078 id: usize,
2079 _task: Task<()>,
2080}
2081
2082#[cfg(test)]
2083mod tests {
2084 use super::*;
2085 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2086 use assistant_settings::AssistantSettings;
2087 use context_server::ContextServerSettings;
2088 use editor::EditorSettings;
2089 use gpui::TestAppContext;
2090 use project::{FakeFs, Project};
2091 use prompt_store::PromptBuilder;
2092 use serde_json::json;
2093 use settings::{Settings, SettingsStore};
2094 use std::sync::Arc;
2095 use theme::ThemeSettings;
2096 use util::path;
2097 use workspace::Workspace;
2098
2099 #[gpui::test]
2100 async fn test_message_with_context(cx: &mut TestAppContext) {
2101 init_test_settings(cx);
2102
2103 let project = create_test_project(
2104 cx,
2105 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2106 )
2107 .await;
2108
2109 let (_workspace, _thread_store, thread, context_store) =
2110 setup_test_environment(cx, project.clone()).await;
2111
2112 add_file_to_context(&project, &context_store, "test/code.rs", cx)
2113 .await
2114 .unwrap();
2115
2116 let context =
2117 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2118
2119 // Insert user message with context
2120 let message_id = thread.update(cx, |thread, cx| {
2121 thread.insert_user_message("Please explain this code", vec![context], None, cx)
2122 });
2123
2124 // Check content and context in message object
2125 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2126
        // The expected context path uses the platform's separator style
2128 #[cfg(windows)]
2129 let path_part = r"test\code.rs";
2130 #[cfg(not(windows))]
2131 let path_part = "test/code.rs";
2132
2133 let expected_context = format!(
2134 r#"
2135<context>
2136The following items were attached by the user. You don't need to use other tools to read them.
2137
2138<files>
2139```rs {path_part}
2140fn main() {{
2141 println!("Hello, world!");
2142}}
2143```
2144</files>
2145</context>
2146"#
2147 );
2148
2149 assert_eq!(message.role, Role::User);
2150 assert_eq!(message.segments.len(), 1);
2151 assert_eq!(
2152 message.segments[0],
2153 MessageSegment::Text("Please explain this code".to_string())
2154 );
2155 assert_eq!(message.context, expected_context);
2156
2157 // Check message in request
2158 let request = thread.read_with(cx, |thread, cx| {
2159 thread.to_completion_request(RequestKind::Chat, cx)
2160 });
2161
2162 assert_eq!(request.messages.len(), 2);
2163 let expected_full_message = format!("{}Please explain this code", expected_context);
2164 assert_eq!(request.messages[1].string_contents(), expected_full_message);
2165 }
2166
2167 #[gpui::test]
2168 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2169 init_test_settings(cx);
2170
2171 let project = create_test_project(
2172 cx,
2173 json!({
2174 "file1.rs": "fn function1() {}\n",
2175 "file2.rs": "fn function2() {}\n",
2176 "file3.rs": "fn function3() {}\n",
2177 }),
2178 )
2179 .await;
2180
2181 let (_, _thread_store, thread, context_store) =
2182 setup_test_environment(cx, project.clone()).await;
2183
2184 // Open files individually
2185 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2186 .await
2187 .unwrap();
2188 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2189 .await
2190 .unwrap();
2191 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2192 .await
2193 .unwrap();
2194
2195 // Get the context objects
2196 let contexts = context_store.update(cx, |store, _| store.context().clone());
2197 assert_eq!(contexts.len(), 3);
2198
2199 // First message with context 1
2200 let message1_id = thread.update(cx, |thread, cx| {
2201 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2202 });
2203
2204 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2205 let message2_id = thread.update(cx, |thread, cx| {
2206 thread.insert_user_message(
2207 "Message 2",
2208 vec![contexts[0].clone(), contexts[1].clone()],
2209 None,
2210 cx,
2211 )
2212 });
2213
2214 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2215 let message3_id = thread.update(cx, |thread, cx| {
2216 thread.insert_user_message(
2217 "Message 3",
2218 vec![
2219 contexts[0].clone(),
2220 contexts[1].clone(),
2221 contexts[2].clone(),
2222 ],
2223 None,
2224 cx,
2225 )
2226 });
2227
2228 // Check what contexts are included in each message
2229 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2230 (
2231 thread.message(message1_id).unwrap().clone(),
2232 thread.message(message2_id).unwrap().clone(),
2233 thread.message(message3_id).unwrap().clone(),
2234 )
2235 });
2236
2237 // First message should include context 1
2238 assert!(message1.context.contains("file1.rs"));
2239
2240 // Second message should include only context 2 (not 1)
2241 assert!(!message2.context.contains("file1.rs"));
2242 assert!(message2.context.contains("file2.rs"));
2243
2244 // Third message should include only context 3 (not 1 or 2)
2245 assert!(!message3.context.contains("file1.rs"));
2246 assert!(!message3.context.contains("file2.rs"));
2247 assert!(message3.context.contains("file3.rs"));
2248
2249 // Check entire request to make sure all contexts are properly included
2250 let request = thread.read_with(cx, |thread, cx| {
2251 thread.to_completion_request(RequestKind::Chat, cx)
2252 });
2253
        // The request should contain the system message plus all 3 user messages
2255 assert_eq!(request.messages.len(), 4);
2256
2257 // Check that the contexts are properly formatted in each message
2258 assert!(request.messages[1].string_contents().contains("file1.rs"));
2259 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2260 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2261
2262 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2263 assert!(request.messages[2].string_contents().contains("file2.rs"));
2264 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2265
2266 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2267 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2268 assert!(request.messages[3].string_contents().contains("file3.rs"));
2269 }
2270
2271 #[gpui::test]
2272 async fn test_message_without_files(cx: &mut TestAppContext) {
2273 init_test_settings(cx);
2274
2275 let project = create_test_project(
2276 cx,
2277 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2278 )
2279 .await;
2280
2281 let (_, _thread_store, thread, _context_store) =
2282 setup_test_environment(cx, project.clone()).await;
2283
2284 // Insert user message without any context (empty context vector)
2285 let message_id = thread.update(cx, |thread, cx| {
2286 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2287 });
2288
2289 // Check content and context in message object
2290 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2291
2292 // Context should be empty when no files are included
2293 assert_eq!(message.role, Role::User);
2294 assert_eq!(message.segments.len(), 1);
2295 assert_eq!(
2296 message.segments[0],
2297 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2298 );
2299 assert_eq!(message.context, "");
2300
2301 // Check message in request
2302 let request = thread.read_with(cx, |thread, cx| {
2303 thread.to_completion_request(RequestKind::Chat, cx)
2304 });
2305
2306 assert_eq!(request.messages.len(), 2);
2307 assert_eq!(
2308 request.messages[1].string_contents(),
2309 "What is the best way to learn Rust?"
2310 );
2311
2312 // Add second message, also without context
2313 let message2_id = thread.update(cx, |thread, cx| {
2314 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2315 });
2316
2317 let message2 =
2318 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2319 assert_eq!(message2.context, "");
2320
2321 // Check that both messages appear in the request
2322 let request = thread.read_with(cx, |thread, cx| {
2323 thread.to_completion_request(RequestKind::Chat, cx)
2324 });
2325
2326 assert_eq!(request.messages.len(), 3);
2327 assert_eq!(
2328 request.messages[1].string_contents(),
2329 "What is the best way to learn Rust?"
2330 );
2331 assert_eq!(
2332 request.messages[2].string_contents(),
2333 "Are there any good books?"
2334 );
2335 }
2336
2337 #[gpui::test]
2338 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2339 init_test_settings(cx);
2340
2341 let project = create_test_project(
2342 cx,
2343 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2344 )
2345 .await;
2346
2347 let (_workspace, _thread_store, thread, context_store) =
2348 setup_test_environment(cx, project.clone()).await;
2349
2350 // Open buffer and add it to context
2351 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2352 .await
2353 .unwrap();
2354
2355 let context =
2356 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2357
2358 // Insert user message with the buffer as context
2359 thread.update(cx, |thread, cx| {
2360 thread.insert_user_message("Explain this code", vec![context], None, cx)
2361 });
2362
2363 // Create a request and check that it doesn't have a stale buffer warning yet
2364 let initial_request = thread.read_with(cx, |thread, cx| {
2365 thread.to_completion_request(RequestKind::Chat, cx)
2366 });
2367
2368 // Make sure we don't have a stale file warning yet
2369 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2370 msg.string_contents()
2371 .contains("These files changed since last read:")
2372 });
2373 assert!(
2374 !has_stale_warning,
2375 "Should not have stale buffer warning before buffer is modified"
2376 );
2377
2378 // Modify the buffer
2379 buffer.update(cx, |buffer, cx| {
            // Insert text near the start of the buffer so it differs from the context already sent
2381 buffer.edit(
2382 [(1..1, "\n println!(\"Added a new line\");\n")],
2383 None,
2384 cx,
2385 );
2386 });
2387
2388 // Insert another user message without context
2389 thread.update(cx, |thread, cx| {
2390 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2391 });
2392
2393 // Create a new request and check for the stale buffer warning
2394 let new_request = thread.read_with(cx, |thread, cx| {
2395 thread.to_completion_request(RequestKind::Chat, cx)
2396 });
2397
2398 // We should have a stale file warning as the last message
2399 let last_message = new_request
2400 .messages
2401 .last()
2402 .expect("Request should have messages");
2403
2404 // The last message should be the stale buffer notification
2405 assert_eq!(last_message.role, Role::User);
2406
2407 // Check the exact content of the message
2408 let expected_content = "These files changed since last read:\n- code.rs\n";
2409 assert_eq!(
2410 last_message.string_contents(),
2411 expected_content,
2412 "Last message should be exactly the stale buffer notification"
2413 );
2414 }
2415
2416 fn init_test_settings(cx: &mut TestAppContext) {
2417 cx.update(|cx| {
2418 let settings_store = SettingsStore::test(cx);
2419 cx.set_global(settings_store);
2420 language::init(cx);
2421 Project::init_settings(cx);
2422 AssistantSettings::register(cx);
2423 thread_store::init(cx);
2424 workspace::init_settings(cx);
2425 ThemeSettings::register(cx);
2426 ContextServerSettings::register(cx);
2427 EditorSettings::register(cx);
2428 });
2429 }
2430
2431 // Helper to create a test project with test files
2432 async fn create_test_project(
2433 cx: &mut TestAppContext,
2434 files: serde_json::Value,
2435 ) -> Entity<Project> {
2436 let fs = FakeFs::new(cx.executor());
2437 fs.insert_tree(path!("/test"), files).await;
2438 Project::test(fs, [path!("/test").as_ref()], cx).await
2439 }
2440
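    // Helper to build a workspace, thread store, thread, and context store for tests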
2441 async fn setup_test_environment(
2442 cx: &mut TestAppContext,
2443 project: Entity<Project>,
2444 ) -> (
2445 Entity<Workspace>,
2446 Entity<ThreadStore>,
2447 Entity<Thread>,
2448 Entity<ContextStore>,
2449 ) {
2450 let (workspace, cx) =
2451 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2452
2453 let thread_store = cx
2454 .update(|_, cx| {
2455 ThreadStore::load(
2456 project.clone(),
2457 cx.new(|_| ToolWorkingSet::default()),
2458 Arc::new(PromptBuilder::new(None).unwrap()),
2459 cx,
2460 )
2461 })
2462 .await;
2463
2464 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2465 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2466
2467 (workspace, thread_store, thread, context_store)
2468 }
2469
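    // Helper to open a buffer for `path` and add it to the context store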
2470 async fn add_file_to_context(
2471 project: &Entity<Project>,
2472 context_store: &Entity<ContextStore>,
2473 path: &str,
2474 cx: &mut TestAppContext,
2475 ) -> Result<Entity<language::Buffer>> {
2476 let buffer_path = project
2477 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2478 .unwrap();
2479
2480 let buffer = project
2481 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2482 .await
2483 .unwrap();
2484
2485 context_store
2486 .update(cx, |store, cx| {
2487 store.add_file_from_buffer(buffer.clone(), cx)
2488 })
2489 .await?;
2490
2491 Ok(buffer)
2492 }
2493}