1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
18};
19use language_model::{
20 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
21 LanguageModelId, LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry,
22 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
23 LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
24 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
25 TokenUsage,
26};
27use project::Project;
28use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
29use prompt_store::PromptBuilder;
30use proto::Plan;
31use schemars::JsonSchema;
32use serde::{Deserialize, Serialize};
33use settings::Settings;
34use thiserror::Error;
35use util::{ResultExt as _, TryFutureExt as _, post_inc};
36use uuid::Uuid;
37
38use crate::context::{AssistantContext, ContextId, format_context_as_string};
39use crate::thread_store::{
40 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
41 SerializedToolUse, SharedProjectContext,
42};
43use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
44
45#[derive(
46 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
47)]
48pub struct ThreadId(Arc<str>);
49
50impl ThreadId {
51 pub fn new() -> Self {
52 Self(Uuid::new_v4().to_string().into())
53 }
54}
55
56impl std::fmt::Display for ThreadId {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 write!(f, "{}", self.0)
59 }
60}
61
62impl From<&str> for ThreadId {
63 fn from(value: &str) -> Self {
64 Self(value.into())
65 }
66}
67
68/// The ID of the user prompt that initiated a request.
69///
70/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
71#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
72pub struct PromptId(Arc<str>);
73
74impl PromptId {
75 pub fn new() -> Self {
76 Self(Uuid::new_v4().to_string().into())
77 }
78}
79
80impl std::fmt::Display for PromptId {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 write!(f, "{}", self.0)
83 }
84}
85
86#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
87pub struct MessageId(pub(crate) usize);
88
89impl MessageId {
90 fn post_inc(&mut self) -> Self {
91 Self(post_inc(&mut self.0))
92 }
93}
94
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    // Ordered content segments (plain text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    // Pre-rendered textual form of the context attached to this message.
    pub context: String,
    // Images attached alongside this message (filled in by `insert_user_message`).
    pub images: Vec<LanguageModelImage>,
}
104
105impl Message {
106 /// Returns whether the message contains any meaningful text that should be displayed
107 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
108 pub fn should_display_content(&self) -> bool {
109 self.segments.iter().all(|segment| segment.should_display())
110 }
111
112 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
113 if let Some(MessageSegment::Thinking {
114 text: segment,
115 signature: current_signature,
116 }) = self.segments.last_mut()
117 {
118 if let Some(signature) = signature {
119 *current_signature = Some(signature);
120 }
121 segment.push_str(text);
122 } else {
123 self.segments.push(MessageSegment::Thinking {
124 text: text.to_string(),
125 signature,
126 });
127 }
128 }
129
130 pub fn push_text(&mut self, text: &str) {
131 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
132 segment.push_str(text);
133 } else {
134 self.segments.push(MessageSegment::Text(text.to_string()));
135 }
136 }
137
138 pub fn to_string(&self) -> String {
139 let mut result = String::new();
140
141 if !self.context.is_empty() {
142 result.push_str(&self.context);
143 }
144
145 for segment in &self.segments {
146 match segment {
147 MessageSegment::Text(text) => result.push_str(text),
148 MessageSegment::Thinking { text, .. } => {
149 result.push_str("<think>\n");
150 result.push_str(text);
151 result.push_str("\n</think>");
152 }
153 MessageSegment::RedactedThinking(_) => {}
154 }
155 }
156
157 result
158 }
159}
160
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment contains displayable text.
    ///
    /// Callers (e.g. `Message::should_display_content`) use this to decide
    /// whether there is meaningful text to show, so non-empty text must
    /// return `true` — the check is `!is_empty()`, not `is_empty()`.
    /// Redacted thinking carries no displayable text.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
180
/// A point-in-time capture of the project's worktrees, taken when a thread is
/// created and stored alongside the serialized thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Paths of buffers that had unsaved edits at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // Git metadata for the worktree, if any was captured.
    pub git_state: Option<GitState>,
}

/// Git metadata captured for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    // Diff text, if captured (see `DiffType` usage elsewhere in this file).
    pub diff: Option<String>,
}
201
/// Associates a message with the git checkpoint taken when it was sent, so the
/// project can later be restored to that state.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up/down) on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

/// The most recent checkpoint-restore operation: either still in flight or
/// failed. Cleared when a restore succeeds (see `restore_checkpoint`).
pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}
223
224impl LastRestoreCheckpoint {
225 pub fn message_id(&self) -> MessageId {
226 match self {
227 LastRestoreCheckpoint::Pending { message_id } => *message_id,
228 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
229 }
230 }
231}
232
/// Progress of generating a detailed thread summary.
///
/// NOTE(review): `message_id` appears to record the last message covered when
/// generation started/finished — confirm against the generation call site.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
245
/// Running token usage for a thread relative to the model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens the context window allows.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, nearing the limit, or over it.
    ///
    /// The warning threshold defaults to 80% of `max`. In debug builds it can
    /// be overridden via the `ZED_THREAD_WARNING_THRESHOLD` env var; a missing
    /// or malformed value falls back to the default instead of panicking
    /// (previously a malformed value would `unwrap()` and crash).
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|threshold| threshold.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When `max` is 0, `total >= max` always holds, so we report
        // `Exceeded` without ever dividing by zero below.
        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Bucketed severity of token usage relative to the context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
286
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // Short user-visible title; `None` until one has been generated.
    summary: Option<SharedString>,
    // In-flight summary-generation task, if any.
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    // Every context entry ever attached to this thread, keyed by ID.
    context: BTreeMap<ContextId, AssistantContext>,
    // Which context IDs were attached with each user message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    // Git checkpoints recorded per message, used by `restore_checkpoint`.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    // Most recent restore still pending or failed; cleared on success.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint for the latest user message; kept or dropped by
    // `finalize_pending_checkpoint` once generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    // Set when the last request exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Optional hook invoked with each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Turns `send_to_model` may still consume; `u32::MAX` means unlimited.
    remaining_turns: u32,
}
322
/// Details recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
330
331impl Thread {
332 pub fn new(
333 project: Entity<Project>,
334 tools: Entity<ToolWorkingSet>,
335 prompt_builder: Arc<PromptBuilder>,
336 system_prompt: SharedProjectContext,
337 cx: &mut Context<Self>,
338 ) -> Self {
339 Self {
340 id: ThreadId::new(),
341 updated_at: Utc::now(),
342 summary: None,
343 pending_summary: Task::ready(None),
344 detailed_summary_state: DetailedSummaryState::NotGenerated,
345 messages: Vec::new(),
346 next_message_id: MessageId(0),
347 last_prompt_id: PromptId::new(),
348 context: BTreeMap::default(),
349 context_by_message: HashMap::default(),
350 project_context: system_prompt,
351 checkpoints_by_message: HashMap::default(),
352 completion_count: 0,
353 pending_completions: Vec::new(),
354 project: project.clone(),
355 prompt_builder,
356 tools: tools.clone(),
357 last_restore_checkpoint: None,
358 pending_checkpoint: None,
359 tool_use: ToolUseState::new(tools.clone()),
360 action_log: cx.new(|_| ActionLog::new(project.clone())),
361 initial_project_snapshot: {
362 let project_snapshot = Self::project_snapshot(project, cx);
363 cx.foreground_executor()
364 .spawn(async move { Some(project_snapshot.await) })
365 .shared()
366 },
367 request_token_usage: Vec::new(),
368 cumulative_token_usage: TokenUsage::default(),
369 exceeded_window_error: None,
370 feedback: None,
371 message_feedback: HashMap::default(),
372 last_auto_capture_at: None,
373 request_callback: None,
374 remaining_turns: u32::MAX,
375 }
376 }
377
378 pub fn deserialize(
379 id: ThreadId,
380 serialized: SerializedThread,
381 project: Entity<Project>,
382 tools: Entity<ToolWorkingSet>,
383 prompt_builder: Arc<PromptBuilder>,
384 project_context: SharedProjectContext,
385 cx: &mut Context<Self>,
386 ) -> Self {
387 let next_message_id = MessageId(
388 serialized
389 .messages
390 .last()
391 .map(|message| message.id.0 + 1)
392 .unwrap_or(0),
393 );
394 let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
395
396 Self {
397 id,
398 updated_at: serialized.updated_at,
399 summary: Some(serialized.summary),
400 pending_summary: Task::ready(None),
401 detailed_summary_state: serialized.detailed_summary_state,
402 messages: serialized
403 .messages
404 .into_iter()
405 .map(|message| Message {
406 id: message.id,
407 role: message.role,
408 segments: message
409 .segments
410 .into_iter()
411 .map(|segment| match segment {
412 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
413 SerializedMessageSegment::Thinking { text, signature } => {
414 MessageSegment::Thinking { text, signature }
415 }
416 SerializedMessageSegment::RedactedThinking { data } => {
417 MessageSegment::RedactedThinking(data)
418 }
419 })
420 .collect(),
421 context: message.context,
422 images: Vec::new(),
423 })
424 .collect(),
425 next_message_id,
426 last_prompt_id: PromptId::new(),
427 context: BTreeMap::default(),
428 context_by_message: HashMap::default(),
429 project_context,
430 checkpoints_by_message: HashMap::default(),
431 completion_count: 0,
432 pending_completions: Vec::new(),
433 last_restore_checkpoint: None,
434 pending_checkpoint: None,
435 project: project.clone(),
436 prompt_builder,
437 tools,
438 tool_use,
439 action_log: cx.new(|_| ActionLog::new(project)),
440 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
441 request_token_usage: serialized.request_token_usage,
442 cumulative_token_usage: serialized.cumulative_token_usage,
443 exceeded_window_error: None,
444 feedback: None,
445 message_feedback: HashMap::default(),
446 last_auto_capture_at: None,
447 request_callback: None,
448 remaining_turns: u32::MAX,
449 }
450 }
451
    /// Installs a hook invoked with every completion request and the raw
    /// events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps `updated_at` to the current time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Replaces `last_prompt_id` with a freshly generated ID.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    // Title used before a summary has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The generated summary, or [`Self::DEFAULT_SUMMARY`] if none exists yet.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
493
494 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
495 let Some(current_summary) = &self.summary else {
496 // Don't allow setting summary until generated
497 return;
498 };
499
500 let mut new_summary = new_summary.into();
501
502 if new_summary.is_empty() {
503 new_summary = Self::DEFAULT_SUMMARY;
504 }
505
506 if current_summary != &new_summary {
507 self.summary = Some(new_summary);
508 cx.emit(ThreadEvent::SummaryChanged);
509 }
510 }
511
512 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
513 self.latest_detailed_summary()
514 .unwrap_or_else(|| self.text().into())
515 }
516
517 fn latest_detailed_summary(&self) -> Option<SharedString> {
518 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
519 Some(text.clone())
520 } else {
521 None
522 }
523 }
524
525 pub fn message(&self, id: MessageId) -> Option<&Message> {
526 let index = self
527 .messages
528 .binary_search_by(|message| message.id.cmp(&id))
529 .ok()?;
530
531 self.messages.get(index)
532 }
533
    /// All messages in the thread, oldest first.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
545
    /// The pending tool use with the given ID, if any.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that require user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for the given message, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
567
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Progress/failure is tracked in `last_restore_checkpoint` so observers
    /// can reflect it; it is cleared again when the restore succeeds.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Any pending checkpoint is stale regardless of the outcome.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
603
    /// Promotes the pending checkpoint into `checkpoints_by_message`, but only
    /// when the project's git state actually changed during the turn —
    /// identical checkpoints would just be noise. Does nothing while
    /// generation is still in progress or when no checkpoint is pending.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // If the comparison itself fails, treat the checkpoints as
                // different and keep the pending one.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Taking the final checkpoint failed: keep the pending one rather
            // than silently dropping it.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
642
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The most recent restore operation that is pending or failed, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
653
654 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
655 let Some(message_ix) = self
656 .messages
657 .iter()
658 .rposition(|message| message.id == message_id)
659 else {
660 return;
661 };
662 for deleted_message in self.messages.drain(message_ix..) {
663 self.context_by_message.remove(&deleted_message.id);
664 self.checkpoints_by_message.remove(&deleted_message.id);
665 }
666 cx.notify();
667 }
668
669 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
670 self.context_by_message
671 .get(&id)
672 .into_iter()
673 .flat_map(|context| {
674 context
675 .iter()
676 .filter_map(|context_id| self.context.get(&context_id))
677 })
678 }
679
680 pub fn is_turn_end(&self, ix: usize) -> bool {
681 if self.messages.is_empty() {
682 return false;
683 }
684
685 if !self.is_generating() && ix == self.messages.len() - 1 {
686 return true;
687 }
688
689 let Some(message) = self.messages.get(ix) else {
690 return false;
691 };
692
693 if message.role != Role::Assistant {
694 return false;
695 }
696
697 self.messages
698 .get(ix + 1)
699 .and_then(|message| {
700 self.message(message.id)
701 .map(|next_message| next_message.role == Role::User)
702 })
703 .unwrap_or(false)
704 }
705
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses issued by the message with the given ID.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced in response to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a completed tool use, if available.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The rendered card associated with a tool's result, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
738
739 /// Filter out contexts that have already been included in previous messages
740 pub fn filter_new_context<'a>(
741 &self,
742 context: impl Iterator<Item = &'a AssistantContext>,
743 ) -> impl Iterator<Item = &'a AssistantContext> {
744 context.filter(|ctx| self.is_context_new(ctx))
745 }
746
747 fn is_context_new(&self, context: &AssistantContext) -> bool {
748 !self.context.contains_key(&context.id())
749 }
750
751 pub fn insert_user_message(
752 &mut self,
753 text: impl Into<String>,
754 context: Vec<AssistantContext>,
755 git_checkpoint: Option<GitStoreCheckpoint>,
756 cx: &mut Context<Self>,
757 ) -> MessageId {
758 let text = text.into();
759
760 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
761
762 let new_context: Vec<_> = context
763 .into_iter()
764 .filter(|ctx| self.is_context_new(ctx))
765 .collect();
766
767 if !new_context.is_empty() {
768 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
769 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
770 message.context = context_string;
771 }
772 }
773
774 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
775 message.images = new_context
776 .iter()
777 .filter_map(|context| {
778 if let AssistantContext::Image(image_context) = context {
779 image_context.image_task.clone().now_or_never().flatten()
780 } else {
781 None
782 }
783 })
784 .collect::<Vec<_>>();
785 }
786
787 self.action_log.update(cx, |log, cx| {
788 // Track all buffers added as context
789 for ctx in &new_context {
790 match ctx {
791 AssistantContext::File(file_ctx) => {
792 log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
793 }
794 AssistantContext::Directory(dir_ctx) => {
795 for context_buffer in &dir_ctx.context_buffers {
796 log.track_buffer(context_buffer.buffer.clone(), cx);
797 }
798 }
799 AssistantContext::Symbol(symbol_ctx) => {
800 log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
801 }
802 AssistantContext::Selection(selection_context) => {
803 log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
804 }
805 AssistantContext::FetchedUrl(_)
806 | AssistantContext::Thread(_)
807 | AssistantContext::Rules(_)
808 | AssistantContext::Image(_) => {}
809 }
810 }
811 });
812 }
813
814 let context_ids = new_context
815 .iter()
816 .map(|context| context.id())
817 .collect::<Vec<_>>();
818 self.context.extend(
819 new_context
820 .into_iter()
821 .map(|context| (context.id(), context)),
822 );
823 self.context_by_message.insert(message_id, context_ids);
824
825 if let Some(git_checkpoint) = git_checkpoint {
826 self.pending_checkpoint = Some(ThreadCheckpoint {
827 message_id,
828 git_checkpoint,
829 });
830 }
831
832 self.auto_capture_telemetry(cx);
833
834 message_id
835 }
836
837 pub fn insert_message(
838 &mut self,
839 role: Role,
840 segments: Vec<MessageSegment>,
841 cx: &mut Context<Self>,
842 ) -> MessageId {
843 let id = self.next_message_id.post_inc();
844 self.messages.push(Message {
845 id,
846 role,
847 segments,
848 context: String::new(),
849 images: Vec::new(),
850 });
851 self.touch_updated_at();
852 cx.emit(ThreadEvent::MessageAdded(id));
853 id
854 }
855
856 pub fn edit_message(
857 &mut self,
858 id: MessageId,
859 new_role: Role,
860 new_segments: Vec<MessageSegment>,
861 cx: &mut Context<Self>,
862 ) -> bool {
863 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
864 return false;
865 };
866 message.role = new_role;
867 message.segments = new_segments;
868 self.touch_updated_at();
869 cx.emit(ThreadEvent::MessageEdited(id));
870 true
871 }
872
873 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
874 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
875 return false;
876 };
877 self.messages.remove(index);
878 self.context_by_message.remove(&id);
879 self.touch_updated_at();
880 cx.emit(ThreadEvent::MessageDeleted(id));
881 true
882 }
883
884 /// Returns the representation of this [`Thread`] in a textual form.
885 ///
886 /// This is the representation we use when attaching a thread as context to another thread.
887 pub fn text(&self) -> String {
888 let mut text = String::new();
889
890 for message in &self.messages {
891 text.push_str(match message.role {
892 language_model::Role::User => "User:",
893 language_model::Role::Assistant => "Assistant:",
894 language_model::Role::System => "System:",
895 });
896 text.push('\n');
897
898 for segment in &message.segments {
899 match segment {
900 MessageSegment::Text(content) => text.push_str(content),
901 MessageSegment::Thinking { text: content, .. } => {
902 text.push_str(&format!("<think>{}</think>", content))
903 }
904 MessageSegment::RedactedThinking(_) => {}
905 }
906 }
907 text.push('\n');
908 }
909
910 text
911 }
912
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the initial snapshot to resolve before reading state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses and results are stored per message so they
                        // can be re-attached on deserialization.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
976
    /// Turns `send_to_model` may still consume; when 0, `send_to_model`
    /// becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
984
985 pub fn send_to_model(
986 &mut self,
987 model: Arc<dyn LanguageModel>,
988 window: Option<AnyWindowHandle>,
989 cx: &mut Context<Self>,
990 ) {
991 if self.remaining_turns == 0 {
992 return;
993 }
994
995 self.remaining_turns -= 1;
996
997 let mut request = self.to_completion_request(cx);
998 if model.supports_tools() {
999 request.tools = {
1000 let mut tools = Vec::new();
1001 tools.extend(
1002 self.tools()
1003 .read(cx)
1004 .enabled_tools(cx)
1005 .into_iter()
1006 .filter_map(|tool| {
1007 // Skip tools that cannot be supported
1008 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
1009 Some(LanguageModelRequestTool {
1010 name: tool.name(),
1011 description: tool.description(),
1012 input_schema,
1013 })
1014 }),
1015 );
1016
1017 tools
1018 };
1019 }
1020
1021 self.stream_completion(request, model, window, cx);
1022 }
1023
1024 pub fn used_tools_since_last_user_message(&self) -> bool {
1025 for message in self.messages.iter().rev() {
1026 if self.tool_use.message_has_tool_results(message.id) {
1027 return true;
1028 } else if message.role == Role::User {
1029 return false;
1030 }
1031 }
1032
1033 false
1034 }
1035
    /// Assembles the full completion request for this thread: system prompt,
    /// each message's context/images/segments, tool uses and results, plus
    /// cache markers.
    pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // Render the system prompt from the shared project context. On
        // failure (or if the context isn't ready) an error is surfaced but
        // the request is still built without a system message.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context is sent as a leading text part.
            if !message.context.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.context.to_string()));
            }

            if !message.images.is_empty() {
                // Some providers only support image parts after an initial text part
                if request_message.content.is_empty() {
                    request_message
                        .content
                        .push(MessageContent::Text("Images attached by user:".to_string()));
                }

                for image in &message.images {
                    request_message
                        .content
                        .push(MessageContent::Image(image.clone()))
                }
            }

            // Empty text/thinking segments are skipped; redacted thinking is
            // always forwarded verbatim.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results follow the assistant message that requested them
            // as a separate message.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request
    }
1148
1149 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1150 let mut request = LanguageModelRequest {
1151 thread_id: None,
1152 prompt_id: None,
1153 messages: vec![],
1154 tools: Vec::new(),
1155 stop: Vec::new(),
1156 temperature: None,
1157 };
1158
1159 for message in &self.messages {
1160 let mut request_message = LanguageModelRequestMessage {
1161 role: message.role,
1162 content: Vec::new(),
1163 cache: false,
1164 };
1165
1166 for segment in &message.segments {
1167 match segment {
1168 MessageSegment::Text(text) => request_message
1169 .content
1170 .push(MessageContent::Text(text.clone())),
1171 MessageSegment::Thinking { .. } => {}
1172 MessageSegment::RedactedThinking(_) => {}
1173 }
1174 }
1175
1176 if request_message.content.is_empty() {
1177 continue;
1178 }
1179
1180 request.messages.push(request_message);
1181 }
1182
1183 request.messages.push(LanguageModelRequestMessage {
1184 role: Role::User,
1185 content: vec![MessageContent::Text(added_user_message)],
1186 cache: false,
1187 });
1188
1189 request
1190 }
1191
1192 fn attached_tracked_files_state(
1193 &self,
1194 messages: &mut Vec<LanguageModelRequestMessage>,
1195 cx: &App,
1196 ) {
1197 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1198
1199 let mut stale_message = String::new();
1200
1201 let action_log = self.action_log.read(cx);
1202
1203 for stale_file in action_log.stale_buffers(cx) {
1204 let Some(file) = stale_file.read(cx).file() else {
1205 continue;
1206 };
1207
1208 if stale_message.is_empty() {
1209 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1210 }
1211
1212 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1213 }
1214
1215 let mut content = Vec::with_capacity(2);
1216
1217 if !stale_message.is_empty() {
1218 content.push(stale_message.into());
1219 }
1220
1221 if !content.is_empty() {
1222 let context_message = LanguageModelRequestMessage {
1223 role: Role::User,
1224 content,
1225 cache: false,
1226 };
1227
1228 messages.push(context_message);
1229 }
1230 }
1231
1232 pub fn stream_completion(
1233 &mut self,
1234 request: LanguageModelRequest,
1235 model: Arc<dyn LanguageModel>,
1236 window: Option<AnyWindowHandle>,
1237 cx: &mut Context<Self>,
1238 ) {
1239 let pending_completion_id = post_inc(&mut self.completion_count);
1240 let mut request_callback_parameters = if self.request_callback.is_some() {
1241 Some((request.clone(), Vec::new()))
1242 } else {
1243 None
1244 };
1245 let prompt_id = self.last_prompt_id.clone();
1246 let tool_use_metadata = ToolUseMetadata {
1247 model: model.clone(),
1248 thread_id: self.id.clone(),
1249 prompt_id: prompt_id.clone(),
1250 };
1251
1252 let task = cx.spawn(async move |thread, cx| {
1253 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1254 let initial_token_usage =
1255 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1256 let stream_completion = async {
1257 let (mut events, usage) = stream_completion_future.await?;
1258
1259 let mut stop_reason = StopReason::EndTurn;
1260 let mut current_token_usage = TokenUsage::default();
1261
1262 if let Some(usage) = usage {
1263 thread
1264 .update(cx, |_thread, cx| {
1265 cx.emit(ThreadEvent::UsageUpdated(usage));
1266 })
1267 .ok();
1268 }
1269
1270 let mut request_assistant_message_id = None;
1271
1272 while let Some(event) = events.next().await {
1273 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1274 response_events
1275 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1276 }
1277
1278 thread.update(cx, |thread, cx| {
1279 let event = match event {
1280 Ok(event) => event,
1281 Err(LanguageModelCompletionError::BadInputJson {
1282 id,
1283 tool_name,
1284 raw_input: invalid_input_json,
1285 json_parse_error,
1286 }) => {
1287 thread.receive_invalid_tool_json(
1288 id,
1289 tool_name,
1290 invalid_input_json,
1291 json_parse_error,
1292 window,
1293 cx,
1294 );
1295 return Ok(());
1296 }
1297 Err(LanguageModelCompletionError::Other(error)) => {
1298 return Err(error);
1299 }
1300 };
1301
1302 match event {
1303 LanguageModelCompletionEvent::StartMessage { .. } => {
1304 request_assistant_message_id = Some(thread.insert_message(
1305 Role::Assistant,
1306 vec![MessageSegment::Text(String::new())],
1307 cx,
1308 ));
1309 }
1310 LanguageModelCompletionEvent::Stop(reason) => {
1311 stop_reason = reason;
1312 }
1313 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1314 thread.update_token_usage_at_last_message(token_usage);
1315 thread.cumulative_token_usage = thread.cumulative_token_usage
1316 + token_usage
1317 - current_token_usage;
1318 current_token_usage = token_usage;
1319 }
1320 LanguageModelCompletionEvent::Text(chunk) => {
1321 cx.emit(ThreadEvent::ReceivedTextChunk);
1322 if let Some(last_message) = thread.messages.last_mut() {
1323 if last_message.role == Role::Assistant
1324 && !thread.tool_use.has_tool_results(last_message.id)
1325 {
1326 last_message.push_text(&chunk);
1327 cx.emit(ThreadEvent::StreamedAssistantText(
1328 last_message.id,
1329 chunk,
1330 ));
1331 } else {
1332 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1333 // of a new Assistant response.
1334 //
1335 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1336 // will result in duplicating the text of the chunk in the rendered Markdown.
1337 request_assistant_message_id = Some(thread.insert_message(
1338 Role::Assistant,
1339 vec![MessageSegment::Text(chunk.to_string())],
1340 cx,
1341 ));
1342 };
1343 }
1344 }
1345 LanguageModelCompletionEvent::Thinking {
1346 text: chunk,
1347 signature,
1348 } => {
1349 if let Some(last_message) = thread.messages.last_mut() {
1350 if last_message.role == Role::Assistant
1351 && !thread.tool_use.has_tool_results(last_message.id)
1352 {
1353 last_message.push_thinking(&chunk, signature);
1354 cx.emit(ThreadEvent::StreamedAssistantThinking(
1355 last_message.id,
1356 chunk,
1357 ));
1358 } else {
1359 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1360 // of a new Assistant response.
1361 //
1362 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1363 // will result in duplicating the text of the chunk in the rendered Markdown.
1364 request_assistant_message_id = Some(thread.insert_message(
1365 Role::Assistant,
1366 vec![MessageSegment::Thinking {
1367 text: chunk.to_string(),
1368 signature,
1369 }],
1370 cx,
1371 ));
1372 };
1373 }
1374 }
1375 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1376 let last_assistant_message_id = request_assistant_message_id
1377 .unwrap_or_else(|| {
1378 let new_assistant_message_id =
1379 thread.insert_message(Role::Assistant, vec![], cx);
1380 request_assistant_message_id =
1381 Some(new_assistant_message_id);
1382 new_assistant_message_id
1383 });
1384
1385 let tool_use_id = tool_use.id.clone();
1386 let streamed_input = if tool_use.is_input_complete {
1387 None
1388 } else {
1389 Some((&tool_use.input).clone())
1390 };
1391
1392 let ui_text = thread.tool_use.request_tool_use(
1393 last_assistant_message_id,
1394 tool_use,
1395 tool_use_metadata.clone(),
1396 cx,
1397 );
1398
1399 if let Some(input) = streamed_input {
1400 cx.emit(ThreadEvent::StreamedToolUse {
1401 tool_use_id,
1402 ui_text,
1403 input,
1404 });
1405 }
1406 }
1407 }
1408
1409 thread.touch_updated_at();
1410 cx.emit(ThreadEvent::StreamedCompletion);
1411 cx.notify();
1412
1413 thread.auto_capture_telemetry(cx);
1414 Ok(())
1415 })??;
1416
1417 smol::future::yield_now().await;
1418 }
1419
1420 thread.update(cx, |thread, cx| {
1421 thread
1422 .pending_completions
1423 .retain(|completion| completion.id != pending_completion_id);
1424
1425 // If there is a response without tool use, summarize the message. Otherwise,
1426 // allow two tool uses before summarizing.
1427 if thread.summary.is_none()
1428 && thread.messages.len() >= 2
1429 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1430 {
1431 thread.summarize(cx);
1432 }
1433 })?;
1434
1435 anyhow::Ok(stop_reason)
1436 };
1437
1438 let result = stream_completion.await;
1439
1440 thread
1441 .update(cx, |thread, cx| {
1442 thread.finalize_pending_checkpoint(cx);
1443 match result.as_ref() {
1444 Ok(stop_reason) => match stop_reason {
1445 StopReason::ToolUse => {
1446 let tool_uses = thread.use_pending_tools(window, cx);
1447 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1448 }
1449 StopReason::EndTurn => {}
1450 StopReason::MaxTokens => {}
1451 },
1452 Err(error) => {
1453 if error.is::<PaymentRequiredError>() {
1454 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1455 } else if error.is::<MaxMonthlySpendReachedError>() {
1456 cx.emit(ThreadEvent::ShowError(
1457 ThreadError::MaxMonthlySpendReached,
1458 ));
1459 } else if let Some(error) =
1460 error.downcast_ref::<ModelRequestLimitReachedError>()
1461 {
1462 cx.emit(ThreadEvent::ShowError(
1463 ThreadError::ModelRequestLimitReached { plan: error.plan },
1464 ));
1465 } else if let Some(known_error) =
1466 error.downcast_ref::<LanguageModelKnownError>()
1467 {
1468 match known_error {
1469 LanguageModelKnownError::ContextWindowLimitExceeded {
1470 tokens,
1471 } => {
1472 thread.exceeded_window_error = Some(ExceededWindowError {
1473 model_id: model.id(),
1474 token_count: *tokens,
1475 });
1476 cx.notify();
1477 }
1478 }
1479 } else {
1480 let error_message = error
1481 .chain()
1482 .map(|err| err.to_string())
1483 .collect::<Vec<_>>()
1484 .join("\n");
1485 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1486 header: "Error interacting with language model".into(),
1487 message: SharedString::from(error_message.clone()),
1488 }));
1489 }
1490
1491 thread.cancel_last_completion(window, cx);
1492 }
1493 }
1494 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1495
1496 if let Some((request_callback, (request, response_events))) = thread
1497 .request_callback
1498 .as_mut()
1499 .zip(request_callback_parameters.as_ref())
1500 {
1501 request_callback(request, response_events);
1502 }
1503
1504 thread.auto_capture_telemetry(cx);
1505
1506 if let Ok(initial_usage) = initial_token_usage {
1507 let usage = thread.cumulative_token_usage - initial_usage;
1508
1509 telemetry::event!(
1510 "Assistant Thread Completion",
1511 thread_id = thread.id().to_string(),
1512 prompt_id = prompt_id,
1513 model = model.telemetry_id(),
1514 model_provider = model.provider_id().to_string(),
1515 input_tokens = usage.input_tokens,
1516 output_tokens = usage.output_tokens,
1517 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1518 cache_read_input_tokens = usage.cache_read_input_tokens,
1519 );
1520 }
1521 })
1522 .ok();
1523 });
1524
1525 self.pending_completions.push(PendingCompletion {
1526 id: pending_completion_id,
1527 _task: task,
1528 });
1529 }
1530
    /// Kicks off generation of a short (3-7 word) title for the thread using
    /// the configured thread-summary model.
    ///
    /// Does nothing when no summary model is configured or its provider is
    /// not authenticated. On success the title is stored in `self.summary`
    /// and a `SummaryGenerated` event is emitted.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    // Only the first line of the response is used as the title.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // An empty response leaves any existing summary untouched.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            // Failures are logged rather than surfaced to the user.
            .log_err()
            .await
        });
    }
1585
    /// Starts generating a detailed Markdown summary of the thread, unless an
    /// up-to-date summary already exists or is already being generated.
    ///
    /// Returns `None` when there are no messages, the existing summary covers
    /// the last message, no summary model is configured, or the provider is
    /// not authenticated; otherwise returns the generation task.
    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
        let last_message_id = self.messages.last().map(|message| message.id)?;

        match &self.detailed_summary_state {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return None;
            }
            _ => {}
        }

        let ConfiguredModel { model, provider } =
            LanguageModelRegistry::read_global(cx).thread_summary_model()?;

        if !provider.is_authenticated(cx) {
            return None;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed: reset the state so a later call can retry.
                thread
                    .update(cx, |this, _cx| {
                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
                    })
                    .log_err();

                return;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate the streamed chunks; errors in individual chunks are
            // logged and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |this, _cx| {
                    this.detailed_summary_state = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .log_err();
        });

        // Mark generation as in-flight before handing the task to the caller.
        self.detailed_summary_state = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        Some(task)
    }
1652
    /// Whether a detailed summary is currently being generated for this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1659
1660 pub fn use_pending_tools(
1661 &mut self,
1662 window: Option<AnyWindowHandle>,
1663 cx: &mut Context<Self>,
1664 ) -> Vec<PendingToolUse> {
1665 self.auto_capture_telemetry(cx);
1666 let request = self.to_completion_request(cx);
1667 let messages = Arc::new(request.messages);
1668 let pending_tool_uses = self
1669 .tool_use
1670 .pending_tool_uses()
1671 .into_iter()
1672 .filter(|tool_use| tool_use.status.is_idle())
1673 .cloned()
1674 .collect::<Vec<_>>();
1675
1676 for tool_use in pending_tool_uses.iter() {
1677 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1678 if tool.needs_confirmation(&tool_use.input, cx)
1679 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1680 {
1681 self.tool_use.confirm_tool_use(
1682 tool_use.id.clone(),
1683 tool_use.ui_text.clone(),
1684 tool_use.input.clone(),
1685 messages.clone(),
1686 tool,
1687 );
1688 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1689 } else {
1690 self.run_tool(
1691 tool_use.id.clone(),
1692 tool_use.ui_text.clone(),
1693 tool_use.input.clone(),
1694 &messages,
1695 tool,
1696 window,
1697 cx,
1698 );
1699 }
1700 }
1701 }
1702
1703 pending_tool_uses
1704 }
1705
1706 pub fn receive_invalid_tool_json(
1707 &mut self,
1708 tool_use_id: LanguageModelToolUseId,
1709 tool_name: Arc<str>,
1710 invalid_json: Arc<str>,
1711 error: String,
1712 window: Option<AnyWindowHandle>,
1713 cx: &mut Context<Thread>,
1714 ) {
1715 log::error!("The model returned invalid input JSON: {invalid_json}");
1716
1717 let pending_tool_use = self.tool_use.insert_tool_output(
1718 tool_use_id.clone(),
1719 tool_name,
1720 Err(anyhow!("Error parsing input JSON: {error}")),
1721 cx,
1722 );
1723 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1724 pending_tool_use.ui_text.clone()
1725 } else {
1726 log::error!(
1727 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1728 );
1729 format!("Unknown tool {}", tool_use_id).into()
1730 };
1731
1732 cx.emit(ThreadEvent::InvalidToolInput {
1733 tool_use_id: tool_use_id.clone(),
1734 ui_text,
1735 invalid_input_json: invalid_json,
1736 });
1737
1738 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1739 }
1740
1741 pub fn run_tool(
1742 &mut self,
1743 tool_use_id: LanguageModelToolUseId,
1744 ui_text: impl Into<SharedString>,
1745 input: serde_json::Value,
1746 messages: &[LanguageModelRequestMessage],
1747 tool: Arc<dyn Tool>,
1748 window: Option<AnyWindowHandle>,
1749 cx: &mut Context<Thread>,
1750 ) {
1751 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1752 self.tool_use
1753 .run_pending_tool(tool_use_id, ui_text.into(), task);
1754 }
1755
    /// Runs `tool` with `input` (unless the tool is disabled) and returns a
    /// task that records the tool's output and finalizes the tool use when it
    /// completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Disabled tools still produce a result — an immediate error — so the
        // model gets feedback instead of a silently dropped tool use.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            cx,
                        );
                        // `canceled = false`: the tool ran to completion, even
                        // if its output was an error.
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
1804
1805 fn tool_finished(
1806 &mut self,
1807 tool_use_id: LanguageModelToolUseId,
1808 pending_tool_use: Option<PendingToolUse>,
1809 canceled: bool,
1810 window: Option<AnyWindowHandle>,
1811 cx: &mut Context<Self>,
1812 ) {
1813 if self.all_tools_finished() {
1814 let model_registry = LanguageModelRegistry::read_global(cx);
1815 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1816 if !canceled {
1817 self.send_to_model(model, window, cx);
1818 }
1819 self.auto_capture_telemetry(cx);
1820 }
1821 }
1822
1823 cx.emit(ThreadEvent::ToolFinished {
1824 tool_use_id,
1825 pending_tool_use,
1826 });
1827 }
1828
1829 /// Cancels the last pending completion, if there are any pending.
1830 ///
1831 /// Returns whether a completion was canceled.
1832 pub fn cancel_last_completion(
1833 &mut self,
1834 window: Option<AnyWindowHandle>,
1835 cx: &mut Context<Self>,
1836 ) -> bool {
1837 let mut canceled = self.pending_completions.pop().is_some();
1838
1839 for pending_tool_use in self.tool_use.cancel_pending() {
1840 canceled = true;
1841 self.tool_finished(
1842 pending_tool_use.id.clone(),
1843 Some(pending_tool_use),
1844 true,
1845 window,
1846 cx,
1847 );
1848 }
1849
1850 self.finalize_pending_checkpoint(cx);
1851 canceled
1852 }
1853
    /// Thread-level feedback, if any was recorded via `report_feedback`.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1857
    /// Feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1861
    /// Records user feedback for a single message and reports it — along with
    /// a serialized snapshot of the thread and project — to telemetry.
    ///
    /// Returns immediately with `Ok(())` when the same feedback was already
    /// recorded for that message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs `cx` up front; the reporting itself
        // runs on the background executor.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failures degrade to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
1919
    /// Records feedback for the whole thread.
    ///
    /// When an assistant message exists, the feedback is attached to the last
    /// one (via `report_message_feedback`); otherwise it is stored at the
    /// thread level and reported to telemetry directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failures degrade to a null payload rather than
                // dropping the event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
1965
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are gathered concurrently; dirty-buffer paths are
    /// collected afterwards on the main context. Failures while reading the
    /// app state simply leave `unsaved_buffer_paths` empty.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Record the paths of all open buffers with unsaved changes.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2003
    /// Builds a snapshot of a single worktree: its absolute path plus git
    /// metadata (remote URL, HEAD sha, branch, diff) when a local repository
    /// containing the worktree is found.
    ///
    /// Any failure along the way degrades to `git_state: None` (or an empty
    /// path) rather than an error.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose path space contains this worktree.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name.to_string());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend to query.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha();
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option/Result layers produced above; any
            // failure collapses to `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2081
2082 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2083 let mut markdown = Vec::new();
2084
2085 if let Some(summary) = self.summary() {
2086 writeln!(markdown, "# {summary}\n")?;
2087 };
2088
2089 for message in self.messages() {
2090 writeln!(
2091 markdown,
2092 "## {role}\n",
2093 role = match message.role {
2094 Role::User => "User",
2095 Role::Assistant => "Assistant",
2096 Role::System => "System",
2097 }
2098 )?;
2099
2100 if !message.context.is_empty() {
2101 writeln!(markdown, "{}", message.context)?;
2102 }
2103
2104 for segment in &message.segments {
2105 match segment {
2106 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2107 MessageSegment::Thinking { text, .. } => {
2108 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2109 }
2110 MessageSegment::RedactedThinking(_) => {}
2111 }
2112 }
2113
2114 for tool_use in self.tool_uses_for_message(message.id, cx) {
2115 writeln!(
2116 markdown,
2117 "**Use Tool: {} ({})**",
2118 tool_use.name, tool_use.id
2119 )?;
2120 writeln!(markdown, "```json")?;
2121 writeln!(
2122 markdown,
2123 "{}",
2124 serde_json::to_string_pretty(&tool_use.input)?
2125 )?;
2126 writeln!(markdown, "```")?;
2127 }
2128
2129 for tool_result in self.tool_results_for_message(message.id) {
2130 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2131 if tool_result.is_error {
2132 write!(markdown, " (Error)")?;
2133 }
2134
2135 writeln!(markdown, "**\n")?;
2136 writeln!(markdown, "{}", tool_result.content)?;
2137 }
2138 }
2139
2140 Ok(String::from_utf8_lossy(&markdown).to_string())
2141 }
2142
2143 pub fn keep_edits_in_range(
2144 &mut self,
2145 buffer: Entity<language::Buffer>,
2146 buffer_range: Range<language::Anchor>,
2147 cx: &mut Context<Self>,
2148 ) {
2149 self.action_log.update(cx, |action_log, cx| {
2150 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2151 });
2152 }
2153
2154 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2155 self.action_log
2156 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2157 }
2158
2159 pub fn reject_edits_in_ranges(
2160 &mut self,
2161 buffer: Entity<language::Buffer>,
2162 buffer_ranges: Vec<Range<language::Anchor>>,
2163 cx: &mut Context<Self>,
2164 ) -> Task<Result<()>> {
2165 self.action_log.update(cx, |action_log, cx| {
2166 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2167 })
2168 }
2169
    /// The action log tracking the agent's edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2173
    /// The project this thread belongs to.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2177
2178 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2179 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
2180 return;
2181 }
2182
2183 let now = Instant::now();
2184 if let Some(last) = self.last_auto_capture_at {
2185 if now.duration_since(last).as_secs() < 10 {
2186 return;
2187 }
2188 }
2189
2190 self.last_auto_capture_at = Some(now);
2191
2192 let thread_id = self.id().clone();
2193 let github_login = self
2194 .project
2195 .read(cx)
2196 .user_store()
2197 .read(cx)
2198 .current_user()
2199 .map(|user| user.github_login.clone());
2200 let client = self.project.read(cx).client().clone();
2201 let serialize_task = self.serialize(cx);
2202
2203 cx.background_executor()
2204 .spawn(async move {
2205 if let Ok(serialized_thread) = serialize_task.await {
2206 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2207 telemetry::event!(
2208 "Agent Thread Auto-Captured",
2209 thread_id = thread_id.to_string(),
2210 thread_data = thread_data,
2211 auto_capture_reason = "tracked_user",
2212 github_login = github_login
2213 );
2214
2215 client.telemetry().flush_events().await;
2216 }
2217 }
2218 })
2219 .detach();
2220 }
2221
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2225
2226 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2227 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2228 return TotalTokenUsage::default();
2229 };
2230
2231 let max = model.model.max_token_count();
2232
2233 let index = self
2234 .messages
2235 .iter()
2236 .position(|msg| msg.id == message_id)
2237 .unwrap_or(0);
2238
2239 if index == 0 {
2240 return TotalTokenUsage { total: 0, max };
2241 }
2242
2243 let token_usage = &self
2244 .request_token_usage
2245 .get(index - 1)
2246 .cloned()
2247 .unwrap_or_default();
2248
2249 TotalTokenUsage {
2250 total: token_usage.total_tokens() as usize,
2251 max,
2252 }
2253 }
2254
2255 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
2256 let model_registry = LanguageModelRegistry::read_global(cx);
2257 let Some(model) = model_registry.default_model() else {
2258 return TotalTokenUsage::default();
2259 };
2260
2261 let max = model.model.max_token_count();
2262
2263 if let Some(exceeded_error) = &self.exceeded_window_error {
2264 if model.model.id() == exceeded_error.model_id {
2265 return TotalTokenUsage {
2266 total: exceeded_error.token_count,
2267 max,
2268 };
2269 }
2270 }
2271
2272 let total = self
2273 .token_usage_at_last_message()
2274 .unwrap_or_default()
2275 .total_tokens() as usize;
2276
2277 TotalTokenUsage { total, max }
2278 }
2279
    /// Token usage recorded for the last message, falling back to the most
    /// recent recorded usage when `request_token_usage` is shorter than the
    /// message list.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2286
2287 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2288 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2289 self.request_token_usage
2290 .resize(self.messages.len(), placeholder);
2291
2292 if let Some(last) = self.request_token_usage.last_mut() {
2293 *last = token_usage;
2294 }
2295 }
2296
2297 pub fn deny_tool_use(
2298 &mut self,
2299 tool_use_id: LanguageModelToolUseId,
2300 tool_name: Arc<str>,
2301 window: Option<AnyWindowHandle>,
2302 cx: &mut Context<Self>,
2303 ) {
2304 let err = Err(anyhow::anyhow!(
2305 "Permission to run tool action denied by user"
2306 ));
2307
2308 self.tool_use
2309 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2310 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2311 }
2312}
2313
/// Errors surfaced while running a thread; carried to observers via
/// [`ThreadEvent::ShowError`].
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before further requests are served.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model-request quota is exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Catch-all error rendered with a short header and a longer message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2328
/// Events emitted by a `Thread` while a completion streams in and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Ask observers to surface the given error to the user.
    ShowError(ThreadError),
    /// Updated request-usage accounting arrived.
    UsageUpdated(RequestUsage),
    /// A completion event was received and processed.
    StreamedCompletion,
    /// A chunk of text arrived from the model stream.
    ReceivedTextChunk,
    /// Streamed assistant response text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool-use request was streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that failed to parse as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A new message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary for the thread was generated.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// The listed pending tool uses should now be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A git checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it may run.
    ToolConfirmationNeeded,
}
2365
// `Thread` broadcasts `ThreadEvent`s to subscribers through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2367
/// An in-flight completion request.
struct PendingCompletion {
    // Monotonic id used to tell concurrent completions apart.
    id: usize,
    // Keeps the background streaming task alive; presumably dropping it
    // cancels the stream (gpui task semantics) — confirm.
    _task: Task<()>,
}
2372
// Tests exercise `Thread` end-to-end against a fake in-memory project:
// context attachment and formatting, deduplication of repeated contexts
// across messages, and stale-buffer notifications in completion requests.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // A user message with an attached file context should carry the formatted
    // context both on the stored message and in the generated request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Message 0 is presumably a system/context message; index 1 is the user's.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    // Contexts already sent in an earlier message must not be repeated in
    // later messages — each message only includes contexts new to the thread.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    // A message with no attached contexts should have an empty `context`
    // string and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    // Editing a buffer after it was attached as context should cause the next
    // request to end with a "files changed since last read" notification.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    // Registers every settings/global the thread machinery reads during tests.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    // Builds the full fixture: a test workspace, a loaded thread store, one
    // fresh thread, and an empty context store for the given project.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    // Opens `path` in the project and registers its buffer as a file context;
    // returns the buffer so tests can mutate it afterwards.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}