1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
18 WeakEntity,
19};
20use language_model::{
21 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
22 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
23 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
24 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
25 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
26 StopReason, TokenUsage,
27};
28use postage::stream::Stream as _;
29use project::Project;
30use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
31use prompt_store::{ModelContext, PromptBuilder};
32use proto::Plan;
33use schemars::JsonSchema;
34use serde::{Deserialize, Serialize};
35use settings::Settings;
36use thiserror::Error;
37use util::{ResultExt as _, TryFutureExt as _, post_inc};
38use uuid::Uuid;
39use zed_llm_client::CompletionMode;
40
41use crate::ThreadStore;
42use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
43use crate::thread_store::{
44 SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, SerializedThread,
45 SerializedToolResult, SerializedToolUse, SharedProjectContext,
46};
47use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
48
/// A stable, globally unique identifier for a [`Thread`].
///
/// Backed by an `Arc<str>`, so cloning is a cheap refcount bump.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
53
54impl ThreadId {
55 pub fn new() -> Self {
56 Self(Uuid::new_v4().to_string().into())
57 }
58}
59
60impl std::fmt::Display for ThreadId {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(f, "{}", self.0)
63 }
64}
65
66impl From<&str> for ThreadId {
67 fn from(value: &str) -> Self {
68 Self(value.into())
69 }
70}
71
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
///
/// Backed by an `Arc<str>`, so cloning is a cheap refcount bump.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
77
78impl PromptId {
79 pub fn new() -> Self {
80 Self(Uuid::new_v4().to_string().into())
81 }
82}
83
84impl std::fmt::Display for PromptId {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 write!(f, "{}", self.0)
87 }
88}
89
90#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
91pub struct MessageId(pub(crate) usize);
92
93impl MessageId {
94 fn post_inc(&mut self) -> Self {
95 Self(post_inc(&mut self.0))
96 }
97}
98
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    // Monotonically increasing id, unique within the owning thread.
    pub id: MessageId,
    // Author of the message: user, assistant, or system.
    pub role: Role,
    // Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    // Context that was loaded for this message when it was sent.
    pub loaded_context: LoadedContext,
}
107
108impl Message {
109 /// Returns whether the message contains any meaningful text that should be displayed
110 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
111 pub fn should_display_content(&self) -> bool {
112 self.segments.iter().all(|segment| segment.should_display())
113 }
114
115 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
116 if let Some(MessageSegment::Thinking {
117 text: segment,
118 signature: current_signature,
119 }) = self.segments.last_mut()
120 {
121 if let Some(signature) = signature {
122 *current_signature = Some(signature);
123 }
124 segment.push_str(text);
125 } else {
126 self.segments.push(MessageSegment::Thinking {
127 text: text.to_string(),
128 signature,
129 });
130 }
131 }
132
133 pub fn push_text(&mut self, text: &str) {
134 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
135 segment.push_str(text);
136 } else {
137 self.segments.push(MessageSegment::Text(text.to_string()));
138 }
139 }
140
141 pub fn to_string(&self) -> String {
142 let mut result = String::new();
143
144 if !self.loaded_context.text.is_empty() {
145 result.push_str(&self.loaded_context.text);
146 }
147
148 for segment in &self.segments {
149 match segment {
150 MessageSegment::Text(text) => result.push_str(text),
151 MessageSegment::Thinking { text, .. } => {
152 result.push_str("<think>\n");
153 result.push_str(text);
154 result.push_str("\n</think>");
155 }
156 MessageSegment::RedactedThinking(_) => {}
157 }
158 }
159
160 result
161 }
162}
163
/// One piece of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain message text.
    Text(String),
    /// A "thinking" block, optionally carrying the provider's signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted thinking, kept as opaque bytes.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment carries displayable content.
    ///
    /// A text or thinking segment is displayable when its text is non-empty;
    /// redacted thinking is never displayed.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
183
/// Point-in-time capture of the project's worktrees, saved alongside a thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Paths of buffers that were unsaved when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // `None` when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}

/// Git metadata captured for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    // Diff text at snapshot time, if one could be computed.
    pub diff: Option<String>,
}
204
/// Associates a git checkpoint with the message whose submission triggered it.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
216
217pub enum LastRestoreCheckpoint {
218 Pending {
219 message_id: MessageId,
220 },
221 Error {
222 message_id: MessageId,
223 error: String,
224 },
225}
226
227impl LastRestoreCheckpoint {
228 pub fn message_id(&self) -> MessageId {
229 match self {
230 LastRestoreCheckpoint::Pending { message_id } => *message_id,
231 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
232 }
233 }
234}
235
236#[derive(Clone, Debug, Default, Serialize, Deserialize)]
237pub enum DetailedSummaryState {
238 #[default]
239 NotGenerated,
240 Generating {
241 message_id: MessageId,
242 },
243 Generated {
244 text: SharedString,
245 message_id: MessageId,
246 },
247}
248
249impl DetailedSummaryState {
250 fn text(&self) -> Option<SharedString> {
251 if let Self::Generated { text, .. } = self {
252 Some(text.clone())
253 } else {
254 None
255 }
256 }
257}
258
/// Running token totals for a thread, relative to the model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// The model's maximum context window; `0` when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via the
        // ZED_THREAD_WARNING_THRESHOLD environment variable.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .unwrap_or("0.8".to_string())
            .parse()
            .expect("ZED_THREAD_WARNING_THRESHOLD must parse as an f32");
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy with `tokens` added to the total; `max` is unchanged.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close a thread is to exhausting its model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
303
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // Short generated title; `None` until a summary has been produced.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: Option<CompletionMode>,
    // Kept sorted by ascending id (relied on by `message`'s binary search).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints captured per user message, used for restores.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken when the current user message was submitted;
    // finalized (kept or dropped) once generation and tools complete.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Optional hook observing each completion request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Agentic turns still allowed; `send_to_model` is a no-op at zero.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}

/// Recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
349
350impl Thread {
    /// Creates an empty thread for `project`.
    ///
    /// The configured model starts as the registry's current default, and an
    /// initial project snapshot is captured asynchronously.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: None,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Kick off the snapshot now; consumers await the shared task.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
401
    /// Reconstructs a `Thread` from its serialized form (e.g. loaded from disk).
    ///
    /// Restores messages, tool-use state, token usage, and the detailed-summary
    /// channel. The previously used model is re-selected from the registry when
    /// available, falling back to the default model.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text is persisted; the
                    // structured contexts and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
495
    /// Installs a hook that observes every completion request together with
    /// the events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
503
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps `updated_at` to the current time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh [`PromptId`] for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if one hasn't been produced yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, initializing it from the registry's
    /// default if none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
547
    /// Title used before a summary has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The generated summary, or [`Self::DEFAULT_SUMMARY`] if none exists yet.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }

    /// Replaces the summary; an empty value falls back to the default title.
    /// Emits [`ThreadEvent::SummaryChanged`] only when the summary changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let Some(current_summary) = &self.summary else {
            // Don't allow setting summary until generated
            return;
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = Self::DEFAULT_SUMMARY;
        }

        if current_summary != &new_summary {
            self.summary = Some(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> Option<CompletionMode> {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: Option<CompletionMode>) {
        self.completion_mode = mode;
    }
579
    /// Looks up a message by id.
    ///
    /// Uses binary search: `messages` is kept sorted by ascending id because
    /// ids are assigned monotonically on insertion.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// Iterates over all messages in insertion order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
592
    /// Whether a completion is streaming or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint captured for a message, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
622
    /// Restores the project to `checkpoint` and, on success, truncates the
    /// thread back to the checkpoint's message.
    ///
    /// Progress and failure are surfaced through
    /// [`Self::last_restore_checkpoint`] and `ThreadEvent::CheckpointChanged`.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so observers can reflect it.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around so it can be shown for this message.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // A restore invalidates any checkpoint pending for this turn.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
658
    /// Once generation has fully finished, decides whether the checkpoint
    /// taken at the start of the turn should be kept.
    ///
    /// The checkpoint is only inserted when the repository state actually
    /// changed during the turn; if the comparison itself fails, the checkpoint
    /// is kept to err on the safe side.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Could not take a final checkpoint: keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
697
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint restore, if one is pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
708
709 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
710 let Some(message_ix) = self
711 .messages
712 .iter()
713 .rposition(|message| message.id == message_id)
714 else {
715 return;
716 };
717 for deleted_message in self.messages.drain(message_ix..) {
718 self.checkpoints_by_message.remove(&deleted_message.id);
719 }
720 cx.notify();
721 }
722
723 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
724 self.messages
725 .iter()
726 .find(|message| message.id == id)
727 .into_iter()
728 .flat_map(|message| message.loaded_context.contexts.iter())
729 }
730
731 pub fn is_turn_end(&self, ix: usize) -> bool {
732 if self.messages.is_empty() {
733 return false;
734 }
735
736 if !self.is_generating() && ix == self.messages.len() - 1 {
737 return true;
738 }
739
740 let Some(message) = self.messages.get(ix) else {
741 return false;
742 };
743
744 if message.role != Role::Assistant {
745 return false;
746 }
747
748 self.messages
749 .get(ix + 1)
750 .and_then(|message| {
751 self.message(message.id)
752 .map(|next_message| next_message.role == Role::User)
753 })
754 .unwrap_or(false)
755 }
756
757 /// Returns whether all of the tool uses have finished running.
758 pub fn all_tools_finished(&self) -> bool {
759 // If the only pending tool uses left are the ones with errors, then
760 // that means that we've finished running all of the pending tools.
761 self.tool_use
762 .pending_tool_uses()
763 .iter()
764 .all(|tool_use| tool_use.status.is_error())
765 }
766
    /// Tool uses issued by the message with the given id.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Results for tool uses issued by the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw output content of a completed tool use.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool result, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
789
790 /// Return tools that are both enabled and supported by the model
791 pub fn available_tools(
792 &self,
793 cx: &App,
794 model: Arc<dyn LanguageModel>,
795 ) -> Vec<LanguageModelRequestTool> {
796 if model.supports_tools() {
797 self.tools()
798 .read(cx)
799 .enabled_tools(cx)
800 .into_iter()
801 .filter_map(|tool| {
802 // Skip tools that cannot be supported
803 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
804 Some(LanguageModelRequestTool {
805 name: tool.name(),
806 description: tool.description(),
807 input_schema,
808 })
809 })
810 .collect()
811 } else {
812 Vec::default()
813 }
814 }
815
    /// Appends a user message, tracking any buffers the context references in
    /// the action log and stashing `git_checkpoint` as this turn's pending
    /// checkpoint.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
849
    /// Appends an assistant message with no attached context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
    }

    /// Appends a message with the next id and emits
    /// [`ThreadEvent::MessageAdded`]. Returns the assigned id.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
876
877 pub fn edit_message(
878 &mut self,
879 id: MessageId,
880 new_role: Role,
881 new_segments: Vec<MessageSegment>,
882 cx: &mut Context<Self>,
883 ) -> bool {
884 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
885 return false;
886 };
887 message.role = new_role;
888 message.segments = new_segments;
889 self.touch_updated_at();
890 cx.emit(ThreadEvent::MessageEdited(id));
891 true
892 }
893
894 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
895 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
896 return false;
897 };
898 self.messages.remove(index);
899 self.touch_updated_at();
900 cx.emit(ThreadEvent::MessageDeleted(id));
901 true
902 }
903
904 /// Returns the representation of this [`Thread`] in a textual form.
905 ///
906 /// This is the representation we use when attaching a thread as context to another thread.
907 pub fn text(&self) -> String {
908 let mut text = String::new();
909
910 for message in &self.messages {
911 text.push_str(match message.role {
912 language_model::Role::User => "User:",
913 language_model::Role::Assistant => "Assistant:",
914 language_model::Role::System => "System:",
915 });
916 text.push('\n');
917
918 for segment in &message.segments {
919 match segment {
920 MessageSegment::Text(content) => text.push_str(content),
921 MessageSegment::Thinking { text: content, .. } => {
922 text.push_str(&format!("<think>{}</think>", content))
923 }
924 MessageSegment::RedactedThinking(_) => {}
925 }
926 }
927 text.push('\n');
928 }
929
930 text
931 }
932
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the initial snapshot before reading thread state so the
            // serialized form is complete.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
            })
        })
    }
1003
    /// Agentic turns still allowed before [`Self::send_to_model`] becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1011
    /// Builds a completion request from the current thread state and starts
    /// streaming it from `model`.
    ///
    /// Consumes one turn from the remaining turn budget; does nothing once the
    /// budget reaches zero.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), cx);

        self.stream_completion(request, model, window, cx);
    }
1028
1029 pub fn used_tools_since_last_user_message(&self) -> bool {
1030 for message in self.messages.iter().rev() {
1031 if self.tool_use.message_has_tool_results(message.id) {
1032 return true;
1033 } else if message.role == Role::User {
1034 return false;
1035 }
1036 }
1037
1038 false
1039 }
1040
    /// Builds the [`LanguageModelRequest`] for the current thread state:
    /// system prompt, the conversation's messages (with their loaded context
    /// and tool uses/results interleaved), and the tools available to `model`.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt is marked cacheable.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            // The request is still sent without a system prompt; the error is
            // surfaced to the user.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Context loaded for this message precedes its own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results follow the message that issued the tool uses.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Max mode is only requested when the model supports it.
        request.mode = if model.supports_max_mode() {
            self.completion_mode
        } else {
            None
        };

        request
    }
1158
1159 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1160 let mut request = LanguageModelRequest {
1161 thread_id: None,
1162 prompt_id: None,
1163 mode: None,
1164 messages: vec![],
1165 tools: Vec::new(),
1166 stop: Vec::new(),
1167 temperature: None,
1168 };
1169
1170 for message in &self.messages {
1171 let mut request_message = LanguageModelRequestMessage {
1172 role: message.role,
1173 content: Vec::new(),
1174 cache: false,
1175 };
1176
1177 for segment in &message.segments {
1178 match segment {
1179 MessageSegment::Text(text) => request_message
1180 .content
1181 .push(MessageContent::Text(text.clone())),
1182 MessageSegment::Thinking { .. } => {}
1183 MessageSegment::RedactedThinking(_) => {}
1184 }
1185 }
1186
1187 if request_message.content.is_empty() {
1188 continue;
1189 }
1190
1191 request.messages.push(request_message);
1192 }
1193
1194 request.messages.push(LanguageModelRequestMessage {
1195 role: Role::User,
1196 content: vec![MessageContent::Text(added_user_message)],
1197 cache: false,
1198 });
1199
1200 request
1201 }
1202
1203 fn attached_tracked_files_state(
1204 &self,
1205 messages: &mut Vec<LanguageModelRequestMessage>,
1206 cx: &App,
1207 ) {
1208 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1209
1210 let mut stale_message = String::new();
1211
1212 let action_log = self.action_log.read(cx);
1213
1214 for stale_file in action_log.stale_buffers(cx) {
1215 let Some(file) = stale_file.read(cx).file() else {
1216 continue;
1217 };
1218
1219 if stale_message.is_empty() {
1220 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1221 }
1222
1223 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1224 }
1225
1226 let mut content = Vec::with_capacity(2);
1227
1228 if !stale_message.is_empty() {
1229 content.push(stale_message.into());
1230 }
1231
1232 if !content.is_empty() {
1233 let context_message = LanguageModelRequestMessage {
1234 role: Role::User,
1235 content,
1236 cache: false,
1237 };
1238
1239 messages.push(context_message);
1240 }
1241 }
1242
    /// Spawns a task that streams a completion for `request` from `model`,
    /// applying each streamed event to this thread as it arrives.
    ///
    /// The task is registered in `self.pending_completions` so it can be
    /// canceled. Streamed events are surfaced to listeners via `ThreadEvent`s,
    /// and the final stop reason decides whether pending tool uses are run.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record the request/response pair when a callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Captured before streaming so the telemetry at the end can report
            // the token delta for this completion alone.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            // Malformed tool-call JSON is recoverable: record it
                            // as an invalid tool input and keep streaming.
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative within one
                                // completion; add only the delta since the
                                // previous update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // While the tool input is still streaming in,
                                // forward the partial input to listeners.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so other tasks get a chance to run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known error types to specific user-facing
                            // events; anything else becomes a generic message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report the token usage for just this completion.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1541
1542 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1543 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1544 return;
1545 };
1546
1547 if !model.provider.is_authenticated(cx) {
1548 return;
1549 }
1550
1551 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1552 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1553 If the conversation is about a specific subject, include it in the title. \
1554 Be descriptive. DO NOT speak in the first person.";
1555
1556 let request = self.to_summarize_request(added_user_message.into());
1557
1558 self.pending_summary = cx.spawn(async move |this, cx| {
1559 async move {
1560 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1561 let (mut messages, usage) = stream.await?;
1562
1563 if let Some(usage) = usage {
1564 this.update(cx, |_thread, cx| {
1565 cx.emit(ThreadEvent::UsageUpdated(usage));
1566 })
1567 .ok();
1568 }
1569
1570 let mut new_summary = String::new();
1571 while let Some(message) = messages.stream.next().await {
1572 let text = message?;
1573 let mut lines = text.lines();
1574 new_summary.extend(lines.next());
1575
1576 // Stop if the LLM generated multiple lines.
1577 if lines.next().is_some() {
1578 break;
1579 }
1580 }
1581
1582 this.update(cx, |this, cx| {
1583 if !new_summary.is_empty() {
1584 this.summary = Some(new_summary.into());
1585 }
1586
1587 cx.emit(ThreadEvent::SummaryGenerated);
1588 })?;
1589
1590 anyhow::Ok(())
1591 }
1592 .log_err()
1593 .await
1594 });
1595 }
1596
    /// Kicks off (or restarts) generation of a detailed Markdown summary of
    /// the thread, unless one is already generating or generated for the
    /// current last message.
    ///
    /// Progress is published through `detailed_summary_tx`; on success the
    /// thread is saved via `thread_store` so the summary can be reused.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // An empty thread has nothing to summarize.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        // Mark generation as in-flight before spawning so observers see the
        // `Generating` state immediately.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed to start; reset the state so a later
                // attempt can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate the streamed summary; individual chunk errors are
            // logged and skipped rather than aborting the whole summary.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1686
1687 pub async fn wait_for_detailed_summary_or_text(
1688 this: &Entity<Self>,
1689 cx: &mut AsyncApp,
1690 ) -> Option<SharedString> {
1691 let mut detailed_summary_rx = this
1692 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1693 .ok()?;
1694 loop {
1695 match detailed_summary_rx.recv().await? {
1696 DetailedSummaryState::Generating { .. } => {}
1697 DetailedSummaryState::NotGenerated => {
1698 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1699 }
1700 DetailedSummaryState::Generated { text, .. } => return Some(text),
1701 }
1702 }
1703 }
1704
1705 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1706 self.detailed_summary_rx
1707 .borrow()
1708 .text()
1709 .unwrap_or_else(|| self.text().into())
1710 }
1711
1712 pub fn is_generating_detailed_summary(&self) -> bool {
1713 matches!(
1714 &*self.detailed_summary_rx.borrow(),
1715 DetailedSummaryState::Generating { .. }
1716 )
1717 }
1718
1719 pub fn use_pending_tools(
1720 &mut self,
1721 window: Option<AnyWindowHandle>,
1722 cx: &mut Context<Self>,
1723 model: Arc<dyn LanguageModel>,
1724 ) -> Vec<PendingToolUse> {
1725 self.auto_capture_telemetry(cx);
1726 let request = self.to_completion_request(model, cx);
1727 let messages = Arc::new(request.messages);
1728 let pending_tool_uses = self
1729 .tool_use
1730 .pending_tool_uses()
1731 .into_iter()
1732 .filter(|tool_use| tool_use.status.is_idle())
1733 .cloned()
1734 .collect::<Vec<_>>();
1735
1736 for tool_use in pending_tool_uses.iter() {
1737 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1738 if tool.needs_confirmation(&tool_use.input, cx)
1739 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1740 {
1741 self.tool_use.confirm_tool_use(
1742 tool_use.id.clone(),
1743 tool_use.ui_text.clone(),
1744 tool_use.input.clone(),
1745 messages.clone(),
1746 tool,
1747 );
1748 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1749 } else {
1750 self.run_tool(
1751 tool_use.id.clone(),
1752 tool_use.ui_text.clone(),
1753 tool_use.input.clone(),
1754 &messages,
1755 tool,
1756 window,
1757 cx,
1758 );
1759 }
1760 }
1761 }
1762
1763 pending_tool_uses
1764 }
1765
1766 pub fn receive_invalid_tool_json(
1767 &mut self,
1768 tool_use_id: LanguageModelToolUseId,
1769 tool_name: Arc<str>,
1770 invalid_json: Arc<str>,
1771 error: String,
1772 window: Option<AnyWindowHandle>,
1773 cx: &mut Context<Thread>,
1774 ) {
1775 log::error!("The model returned invalid input JSON: {invalid_json}");
1776
1777 let pending_tool_use = self.tool_use.insert_tool_output(
1778 tool_use_id.clone(),
1779 tool_name,
1780 Err(anyhow!("Error parsing input JSON: {error}")),
1781 self.configured_model.as_ref(),
1782 );
1783 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1784 pending_tool_use.ui_text.clone()
1785 } else {
1786 log::error!(
1787 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1788 );
1789 format!("Unknown tool {}", tool_use_id).into()
1790 };
1791
1792 cx.emit(ThreadEvent::InvalidToolInput {
1793 tool_use_id: tool_use_id.clone(),
1794 ui_text,
1795 invalid_input_json: invalid_json,
1796 });
1797
1798 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1799 }
1800
1801 pub fn run_tool(
1802 &mut self,
1803 tool_use_id: LanguageModelToolUseId,
1804 ui_text: impl Into<SharedString>,
1805 input: serde_json::Value,
1806 messages: &[LanguageModelRequestMessage],
1807 tool: Arc<dyn Tool>,
1808 window: Option<AnyWindowHandle>,
1809 cx: &mut Context<Thread>,
1810 ) {
1811 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1812 self.tool_use
1813 .run_pending_tool(tool_use_id, ui_text.into(), task);
1814 }
1815
1816 fn spawn_tool_use(
1817 &mut self,
1818 tool_use_id: LanguageModelToolUseId,
1819 messages: &[LanguageModelRequestMessage],
1820 input: serde_json::Value,
1821 tool: Arc<dyn Tool>,
1822 window: Option<AnyWindowHandle>,
1823 cx: &mut Context<Thread>,
1824 ) -> Task<()> {
1825 let tool_name: Arc<str> = tool.name().into();
1826
1827 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1828 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1829 } else {
1830 tool.run(
1831 input,
1832 messages,
1833 self.project.clone(),
1834 self.action_log.clone(),
1835 window,
1836 cx,
1837 )
1838 };
1839
1840 // Store the card separately if it exists
1841 if let Some(card) = tool_result.card.clone() {
1842 self.tool_use
1843 .insert_tool_result_card(tool_use_id.clone(), card);
1844 }
1845
1846 cx.spawn({
1847 async move |thread: WeakEntity<Thread>, cx| {
1848 let output = tool_result.output.await;
1849
1850 thread
1851 .update(cx, |thread, cx| {
1852 let pending_tool_use = thread.tool_use.insert_tool_output(
1853 tool_use_id.clone(),
1854 tool_name,
1855 output,
1856 thread.configured_model.as_ref(),
1857 );
1858 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1859 })
1860 .ok();
1861 }
1862 })
1863 }
1864
1865 fn tool_finished(
1866 &mut self,
1867 tool_use_id: LanguageModelToolUseId,
1868 pending_tool_use: Option<PendingToolUse>,
1869 canceled: bool,
1870 window: Option<AnyWindowHandle>,
1871 cx: &mut Context<Self>,
1872 ) {
1873 if self.all_tools_finished() {
1874 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
1875 if !canceled {
1876 self.send_to_model(model.clone(), window, cx);
1877 }
1878 self.auto_capture_telemetry(cx);
1879 }
1880 }
1881
1882 cx.emit(ThreadEvent::ToolFinished {
1883 tool_use_id,
1884 pending_tool_use,
1885 });
1886 }
1887
1888 /// Cancels the last pending completion, if there are any pending.
1889 ///
1890 /// Returns whether a completion was canceled.
1891 pub fn cancel_last_completion(
1892 &mut self,
1893 window: Option<AnyWindowHandle>,
1894 cx: &mut Context<Self>,
1895 ) -> bool {
1896 let mut canceled = self.pending_completions.pop().is_some();
1897
1898 for pending_tool_use in self.tool_use.cancel_pending() {
1899 canceled = true;
1900 self.tool_finished(
1901 pending_tool_use.id.clone(),
1902 Some(pending_tool_use),
1903 true,
1904 window,
1905 cx,
1906 );
1907 }
1908
1909 self.finalize_pending_checkpoint(cx);
1910 canceled
1911 }
1912
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        // No thread state is modified here; cancellation is entirely
        // event-driven and handled by the listeners.
        cx.emit(ThreadEvent::CancelEditing);
    }
1920
    /// Returns the thread-level feedback, if any was recorded.
    ///
    /// This field is set by `report_feedback` when there is no assistant
    /// message to attach per-message feedback to.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1924
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1928
1929 pub fn report_message_feedback(
1930 &mut self,
1931 message_id: MessageId,
1932 feedback: ThreadFeedback,
1933 cx: &mut Context<Self>,
1934 ) -> Task<Result<()>> {
1935 if self.message_feedback.get(&message_id) == Some(&feedback) {
1936 return Task::ready(Ok(()));
1937 }
1938
1939 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1940 let serialized_thread = self.serialize(cx);
1941 let thread_id = self.id().clone();
1942 let client = self.project.read(cx).client();
1943
1944 let enabled_tool_names: Vec<String> = self
1945 .tools()
1946 .read(cx)
1947 .enabled_tools(cx)
1948 .iter()
1949 .map(|tool| tool.name().to_string())
1950 .collect();
1951
1952 self.message_feedback.insert(message_id, feedback);
1953
1954 cx.notify();
1955
1956 let message_content = self
1957 .message(message_id)
1958 .map(|msg| msg.to_string())
1959 .unwrap_or_default();
1960
1961 cx.background_spawn(async move {
1962 let final_project_snapshot = final_project_snapshot.await;
1963 let serialized_thread = serialized_thread.await?;
1964 let thread_data =
1965 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1966
1967 let rating = match feedback {
1968 ThreadFeedback::Positive => "positive",
1969 ThreadFeedback::Negative => "negative",
1970 };
1971 telemetry::event!(
1972 "Assistant Thread Rated",
1973 rating,
1974 thread_id,
1975 enabled_tool_names,
1976 message_id = message_id.0,
1977 message_content,
1978 thread_data,
1979 final_project_snapshot
1980 );
1981 client.telemetry().flush_events().await;
1982
1983 Ok(())
1984 })
1985 }
1986
1987 pub fn report_feedback(
1988 &mut self,
1989 feedback: ThreadFeedback,
1990 cx: &mut Context<Self>,
1991 ) -> Task<Result<()>> {
1992 let last_assistant_message_id = self
1993 .messages
1994 .iter()
1995 .rev()
1996 .find(|msg| msg.role == Role::Assistant)
1997 .map(|msg| msg.id);
1998
1999 if let Some(message_id) = last_assistant_message_id {
2000 self.report_message_feedback(message_id, feedback, cx)
2001 } else {
2002 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2003 let serialized_thread = self.serialize(cx);
2004 let thread_id = self.id().clone();
2005 let client = self.project.read(cx).client();
2006 self.feedback = Some(feedback);
2007 cx.notify();
2008
2009 cx.background_spawn(async move {
2010 let final_project_snapshot = final_project_snapshot.await;
2011 let serialized_thread = serialized_thread.await?;
2012 let thread_data = serde_json::to_value(serialized_thread)
2013 .unwrap_or_else(|_| serde_json::Value::Null);
2014
2015 let rating = match feedback {
2016 ThreadFeedback::Positive => "positive",
2017 ThreadFeedback::Negative => "negative",
2018 };
2019 telemetry::event!(
2020 "Assistant Thread Rated",
2021 rating,
2022 thread_id,
2023 thread_data,
2024 final_project_snapshot
2025 );
2026 client.telemetry().flush_events().await;
2027
2028 Ok(())
2029 })
2030 }
2031 }
2032
2033 /// Create a snapshot of the current project state including git information and unsaved buffers.
2034 fn project_snapshot(
2035 project: Entity<Project>,
2036 cx: &mut Context<Self>,
2037 ) -> Task<Arc<ProjectSnapshot>> {
2038 let git_store = project.read(cx).git_store().clone();
2039 let worktree_snapshots: Vec<_> = project
2040 .read(cx)
2041 .visible_worktrees(cx)
2042 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2043 .collect();
2044
2045 cx.spawn(async move |_, cx| {
2046 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2047
2048 let mut unsaved_buffers = Vec::new();
2049 cx.update(|app_cx| {
2050 let buffer_store = project.read(app_cx).buffer_store();
2051 for buffer_handle in buffer_store.read(app_cx).buffers() {
2052 let buffer = buffer_handle.read(app_cx);
2053 if buffer.is_dirty() {
2054 if let Some(file) = buffer.file() {
2055 let path = file.path().to_string_lossy().to_string();
2056 unsaved_buffers.push(path);
2057 }
2058 }
2059 }
2060 })
2061 .ok();
2062
2063 Arc::new(ProjectSnapshot {
2064 worktree_snapshots,
2065 unsaved_buffer_paths: unsaved_buffers,
2066 timestamp: Utc::now(),
2067 })
2068 })
2069 }
2070
2071 fn worktree_snapshot(
2072 worktree: Entity<project::Worktree>,
2073 git_store: Entity<GitStore>,
2074 cx: &App,
2075 ) -> Task<WorktreeSnapshot> {
2076 cx.spawn(async move |cx| {
2077 // Get worktree path and snapshot
2078 let worktree_info = cx.update(|app_cx| {
2079 let worktree = worktree.read(app_cx);
2080 let path = worktree.abs_path().to_string_lossy().to_string();
2081 let snapshot = worktree.snapshot();
2082 (path, snapshot)
2083 });
2084
2085 let Ok((worktree_path, _snapshot)) = worktree_info else {
2086 return WorktreeSnapshot {
2087 worktree_path: String::new(),
2088 git_state: None,
2089 };
2090 };
2091
2092 let git_state = git_store
2093 .update(cx, |git_store, cx| {
2094 git_store
2095 .repositories()
2096 .values()
2097 .find(|repo| {
2098 repo.read(cx)
2099 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2100 .is_some()
2101 })
2102 .cloned()
2103 })
2104 .ok()
2105 .flatten()
2106 .map(|repo| {
2107 repo.update(cx, |repo, _| {
2108 let current_branch =
2109 repo.branch.as_ref().map(|branch| branch.name.to_string());
2110 repo.send_job(None, |state, _| async move {
2111 let RepositoryState::Local { backend, .. } = state else {
2112 return GitState {
2113 remote_url: None,
2114 head_sha: None,
2115 current_branch,
2116 diff: None,
2117 };
2118 };
2119
2120 let remote_url = backend.remote_url("origin");
2121 let head_sha = backend.head_sha().await;
2122 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2123
2124 GitState {
2125 remote_url,
2126 head_sha,
2127 current_branch,
2128 diff,
2129 }
2130 })
2131 })
2132 });
2133
2134 let git_state = match git_state {
2135 Some(git_state) => match git_state.ok() {
2136 Some(git_state) => git_state.await.ok(),
2137 None => None,
2138 },
2139 None => None,
2140 };
2141
2142 WorktreeSnapshot {
2143 worktree_path,
2144 git_state,
2145 }
2146 })
2147 }
2148
2149 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2150 let mut markdown = Vec::new();
2151
2152 if let Some(summary) = self.summary() {
2153 writeln!(markdown, "# {summary}\n")?;
2154 };
2155
2156 for message in self.messages() {
2157 writeln!(
2158 markdown,
2159 "## {role}\n",
2160 role = match message.role {
2161 Role::User => "User",
2162 Role::Assistant => "Assistant",
2163 Role::System => "System",
2164 }
2165 )?;
2166
2167 if !message.loaded_context.text.is_empty() {
2168 writeln!(markdown, "{}", message.loaded_context.text)?;
2169 }
2170
2171 if !message.loaded_context.images.is_empty() {
2172 writeln!(
2173 markdown,
2174 "\n{} images attached as context.\n",
2175 message.loaded_context.images.len()
2176 )?;
2177 }
2178
2179 for segment in &message.segments {
2180 match segment {
2181 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2182 MessageSegment::Thinking { text, .. } => {
2183 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2184 }
2185 MessageSegment::RedactedThinking(_) => {}
2186 }
2187 }
2188
2189 for tool_use in self.tool_uses_for_message(message.id, cx) {
2190 writeln!(
2191 markdown,
2192 "**Use Tool: {} ({})**",
2193 tool_use.name, tool_use.id
2194 )?;
2195 writeln!(markdown, "```json")?;
2196 writeln!(
2197 markdown,
2198 "{}",
2199 serde_json::to_string_pretty(&tool_use.input)?
2200 )?;
2201 writeln!(markdown, "```")?;
2202 }
2203
2204 for tool_result in self.tool_results_for_message(message.id) {
2205 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2206 if tool_result.is_error {
2207 write!(markdown, " (Error)")?;
2208 }
2209
2210 writeln!(markdown, "**\n")?;
2211 writeln!(markdown, "{}", tool_result.content)?;
2212 }
2213 }
2214
2215 Ok(String::from_utf8_lossy(&markdown).to_string())
2216 }
2217
    /// Marks the agent's edits within `buffer_range` of `buffer` as accepted.
    ///
    /// Delegates to the shared [`ActionLog`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2228
    /// Marks all of the agent's outstanding edits as accepted.
    ///
    /// Delegates to the shared [`ActionLog`].
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2233
    /// Reverts the agent's edits within the given ranges of `buffer`.
    ///
    /// Delegates to the shared [`ActionLog`]; the returned task resolves when
    /// the rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2244
    /// Returns the action log tracking the agent's edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2248
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2252
2253 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2254 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2255 return;
2256 }
2257
2258 let now = Instant::now();
2259 if let Some(last) = self.last_auto_capture_at {
2260 if now.duration_since(last).as_secs() < 10 {
2261 return;
2262 }
2263 }
2264
2265 self.last_auto_capture_at = Some(now);
2266
2267 let thread_id = self.id().clone();
2268 let github_login = self
2269 .project
2270 .read(cx)
2271 .user_store()
2272 .read(cx)
2273 .current_user()
2274 .map(|user| user.github_login.clone());
2275 let client = self.project.read(cx).client().clone();
2276 let serialize_task = self.serialize(cx);
2277
2278 cx.background_executor()
2279 .spawn(async move {
2280 if let Ok(serialized_thread) = serialize_task.await {
2281 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2282 telemetry::event!(
2283 "Agent Thread Auto-Captured",
2284 thread_id = thread_id.to_string(),
2285 thread_data = thread_data,
2286 auto_capture_reason = "tracked_user",
2287 github_login = github_login
2288 );
2289
2290 client.telemetry().flush_events().await;
2291 }
2292 }
2293 })
2294 .detach();
2295 }
2296
2297 pub fn cumulative_token_usage(&self) -> TokenUsage {
2298 self.cumulative_token_usage
2299 }
2300
2301 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2302 let Some(model) = self.configured_model.as_ref() else {
2303 return TotalTokenUsage::default();
2304 };
2305
2306 let max = model.model.max_token_count();
2307
2308 let index = self
2309 .messages
2310 .iter()
2311 .position(|msg| msg.id == message_id)
2312 .unwrap_or(0);
2313
2314 if index == 0 {
2315 return TotalTokenUsage { total: 0, max };
2316 }
2317
2318 let token_usage = &self
2319 .request_token_usage
2320 .get(index - 1)
2321 .cloned()
2322 .unwrap_or_default();
2323
2324 TotalTokenUsage {
2325 total: token_usage.total_tokens() as usize,
2326 max,
2327 }
2328 }
2329
2330 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2331 let model = self.configured_model.as_ref()?;
2332
2333 let max = model.model.max_token_count();
2334
2335 if let Some(exceeded_error) = &self.exceeded_window_error {
2336 if model.model.id() == exceeded_error.model_id {
2337 return Some(TotalTokenUsage {
2338 total: exceeded_error.token_count,
2339 max,
2340 });
2341 }
2342 }
2343
2344 let total = self
2345 .token_usage_at_last_message()
2346 .unwrap_or_default()
2347 .total_tokens() as usize;
2348
2349 Some(TotalTokenUsage { total, max })
2350 }
2351
2352 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2353 self.request_token_usage
2354 .get(self.messages.len().saturating_sub(1))
2355 .or_else(|| self.request_token_usage.last())
2356 .cloned()
2357 }
2358
2359 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2360 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2361 self.request_token_usage
2362 .resize(self.messages.len(), placeholder);
2363
2364 if let Some(last) = self.request_token_usage.last_mut() {
2365 *last = token_usage;
2366 }
2367 }
2368
2369 pub fn deny_tool_use(
2370 &mut self,
2371 tool_use_id: LanguageModelToolUseId,
2372 tool_name: Arc<str>,
2373 window: Option<AnyWindowHandle>,
2374 cx: &mut Context<Self>,
2375 ) {
2376 let err = Err(anyhow::anyhow!(
2377 "Permission to run tool action denied by user"
2378 ));
2379
2380 self.tool_use.insert_tool_output(
2381 tool_use_id.clone(),
2382 tool_name,
2383 err,
2384 self.configured_model.as_ref(),
2385 );
2386 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2387 }
2388}
2389
/// Fatal errors surfaced to the UI while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before further completions are allowed.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured monthly spend cap has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request quota for the user's plan has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, displayed with a short header and a longer message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2404
/// Events emitted by a `Thread` while a completion streams and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// A fatal error that should be shown to the user.
    ShowError(ThreadError),
    /// Updated request-usage figures were received.
    UsageUpdated(RequestUsage),
    /// A completion event arrived from the model stream.
    StreamedCompletion,
    /// A chunk of text was received from the stream.
    ReceivedTextChunk,
    /// Assistant response text streamed in for the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text streamed in for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with parsed JSON input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that failed to parse as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, with either a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint state changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it may run.
    ToolConfirmationNeeded,
    /// Editing of a message was canceled.
    CancelEditing,
}
2442
// Allows `Thread` to emit `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2444
/// An in-flight completion request.
struct PendingCompletion {
    /// Monotonic id used to tell concurrent completions apart.
    id: usize,
    /// Keeps the background streaming task alive while this completion is
    /// pending.
    _task: Task<()>,
}
2449
2450#[cfg(test)]
2451mod tests {
2452 use super::*;
2453 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2454 use assistant_settings::AssistantSettings;
2455 use assistant_tool::ToolRegistry;
2456 use context_server::ContextServerSettings;
2457 use editor::EditorSettings;
2458 use gpui::TestAppContext;
2459 use language_model::fake_provider::FakeLanguageModel;
2460 use project::{FakeFs, Project};
2461 use prompt_store::PromptBuilder;
2462 use serde_json::json;
2463 use settings::{Settings, SettingsStore};
2464 use std::sync::Arc;
2465 use theme::ThemeSettings;
2466 use util::path;
2467 use workspace::Workspace;
2468
    // Verifies that a file attached as context is rendered into the user
    // message's `<context>` block and included verbatim in the completion
    // request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        // Load the single attached context entry into message-ready form.
        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", loaded_context, None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // messages[0] is presumably the system message — verify against
        // `to_completion_request` if this changes.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2538
    // Verifies that each user message only embeds context entries that were
    // not already attached to an earlier message in the same thread.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }
2644
    // Verifies that user messages inserted without any attached context carry
    // empty context text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2720
    // Verifies that editing a buffer after it was attached as context causes
    // the completion request to end with a "files changed since last read"
    // notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2806
    // Registers all global settings and registries the thread tests depend
    // on. The `SettingsStore` is installed first — presumably the subsequent
    // `init`/`register` calls read it from globals; verify before reordering.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
2824
2825 // Helper to create a test project with test files
2826 async fn create_test_project(
2827 cx: &mut TestAppContext,
2828 files: serde_json::Value,
2829 ) -> Entity<Project> {
2830 let fs = FakeFs::new(cx.executor());
2831 fs.insert_tree(path!("/test"), files).await;
2832 Project::test(fs, [path!("/test").as_ref()], cx).await
2833 }
2834
2835 async fn setup_test_environment(
2836 cx: &mut TestAppContext,
2837 project: Entity<Project>,
2838 ) -> (
2839 Entity<Workspace>,
2840 Entity<ThreadStore>,
2841 Entity<Thread>,
2842 Entity<ContextStore>,
2843 Arc<dyn LanguageModel>,
2844 ) {
2845 let (workspace, cx) =
2846 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2847
2848 let thread_store = cx
2849 .update(|_, cx| {
2850 ThreadStore::load(
2851 project.clone(),
2852 cx.new(|_| ToolWorkingSet::default()),
2853 None,
2854 Arc::new(PromptBuilder::new(None).unwrap()),
2855 cx,
2856 )
2857 })
2858 .await
2859 .unwrap();
2860
2861 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2862 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2863
2864 let model = FakeLanguageModel::default();
2865 let model: Arc<dyn LanguageModel> = Arc::new(model);
2866
2867 (workspace, thread_store, thread, context_store, model)
2868 }
2869
2870 async fn add_file_to_context(
2871 project: &Entity<Project>,
2872 context_store: &Entity<ContextStore>,
2873 path: &str,
2874 cx: &mut TestAppContext,
2875 ) -> Result<Entity<language::Buffer>> {
2876 let buffer_path = project
2877 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2878 .unwrap();
2879
2880 let buffer = project
2881 .update(cx, |project, cx| {
2882 project.open_buffer(buffer_path.clone(), cx)
2883 })
2884 .await
2885 .unwrap();
2886
2887 context_store.update(cx, |context_store, cx| {
2888 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
2889 });
2890
2891 Ok(buffer)
2892 }
2893}