1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
18 WeakEntity,
19};
20use language_model::{
21 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
22 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
23 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
24 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
25 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
26 StopReason, TokenUsage,
27};
28use postage::stream::Stream as _;
29use project::Project;
30use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
31use prompt_store::{ModelContext, PromptBuilder};
32use proto::Plan;
33use schemars::JsonSchema;
34use serde::{Deserialize, Serialize};
35use settings::Settings;
36use thiserror::Error;
37use util::{ResultExt as _, TryFutureExt as _, post_inc};
38use uuid::Uuid;
39use zed_llm_client::CompletionMode;
40
41use crate::ThreadStore;
42use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
43use crate::thread_store::{
44 SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, SerializedThread,
45 SerializedToolResult, SerializedToolUse, SharedProjectContext,
46};
47use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
48
/// Unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
53
54impl ThreadId {
55 pub fn new() -> Self {
56 Self(Uuid::new_v4().to_string().into())
57 }
58}
59
60impl std::fmt::Display for ThreadId {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(f, "{}", self.0)
63 }
64}
65
66impl From<&str> for ThreadId {
67 fn from(value: &str) -> Self {
68 Self(value.into())
69 }
70}
71
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
77
78impl PromptId {
79 pub fn new() -> Self {
80 Self(Uuid::new_v4().to_string().into())
81 }
82}
83
84impl std::fmt::Display for PromptId {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 write!(f, "{}", self.0)
87 }
88}
89
/// Sequential identifier for a message within a single [`Thread`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
92
93impl MessageId {
94 fn post_inc(&mut self) -> Self {
95 Self(post_inc(&mut self.0))
96 }
97}
98
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier of this message within its thread.
    pub id: MessageId,
    /// Whether the user, assistant, or system authored the message.
    pub role: Role,
    /// Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context (files etc.) that was attached when this message was created.
    pub loaded_context: LoadedContext,
}
107
108impl Message {
109 /// Returns whether the message contains any meaningful text that should be displayed
110 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
111 pub fn should_display_content(&self) -> bool {
112 self.segments.iter().all(|segment| segment.should_display())
113 }
114
115 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
116 if let Some(MessageSegment::Thinking {
117 text: segment,
118 signature: current_signature,
119 }) = self.segments.last_mut()
120 {
121 if let Some(signature) = signature {
122 *current_signature = Some(signature);
123 }
124 segment.push_str(text);
125 } else {
126 self.segments.push(MessageSegment::Thinking {
127 text: text.to_string(),
128 signature,
129 });
130 }
131 }
132
133 pub fn push_text(&mut self, text: &str) {
134 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
135 segment.push_str(text);
136 } else {
137 self.segments.push(MessageSegment::Text(text.to_string()));
138 }
139 }
140
141 pub fn to_string(&self) -> String {
142 let mut result = String::new();
143
144 if !self.loaded_context.text.is_empty() {
145 result.push_str(&self.loaded_context.text);
146 }
147
148 for segment in &self.segments {
149 match segment {
150 MessageSegment::Text(text) => result.push_str(text),
151 MessageSegment::Thinking { text, .. } => {
152 result.push_str("<think>\n");
153 result.push_str(text);
154 result.push_str("\n</think>");
155 }
156 MessageSegment::RedactedThinking(_) => {}
157 }
158 }
159
160 result
161 }
162}
163
/// One piece of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain assistant/user text.
    Text(String),
    /// Model reasoning text, optionally carrying a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted reasoning; opaque bytes, never displayed.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has meaningful, displayable content.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking is never displayed.
    pub fn should_display(&self) -> bool {
        match self {
            // BUG FIX: these previously returned `text.is_empty()`, which
            // inverted the meaning of "should display" — empty segments were
            // reported as displayable and non-empty ones as hidden.
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
183
/// Snapshot of the project's state, captured when a thread starts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One snapshot per worktree in the project.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers with unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
190
/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Absolute path of the worktree root.
    pub worktree_path: String,
    /// Git metadata, or `None` when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}
196
/// Git metadata captured for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the primary remote, if configured.
    pub remote_url: Option<String>,
    /// SHA of the current HEAD commit, if resolvable.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch, if any.
    pub current_branch: Option<String>,
    /// Uncommitted diff text, if one could be computed.
    pub diff: Option<String>,
}
204
/// A git checkpoint tied to the user message that triggered it,
/// allowing the project to be restored to that point.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    // The user message this checkpoint was taken for.
    message_id: MessageId,
    // The underlying git state to restore.
    git_checkpoint: GitStoreCheckpoint,
}
210
/// User feedback (thumbs up / thumbs down) on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
216
/// State of the most recent checkpoint restore attempt.
pub enum LastRestoreCheckpoint {
    /// The restore is still in flight.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed with the given error message.
    Error {
        message_id: MessageId,
        error: String,
    },
}
226
227impl LastRestoreCheckpoint {
228 pub fn message_id(&self) -> MessageId {
229 match self {
230 LastRestoreCheckpoint::Pending { message_id } => *message_id,
231 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
232 }
233 }
234}
235
/// Lifecycle of the thread's detailed (long-form) summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been requested yet.
    #[default]
    NotGenerated,
    /// A summary is being generated, covering messages up to `message_id`.
    Generating {
        message_id: MessageId,
    },
    /// A summary exists, covering messages up to `message_id`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
248
249impl DetailedSummaryState {
250 fn text(&self) -> Option<SharedString> {
251 if let Self::Generated { text, .. } = self {
252 Some(text.clone())
253 } else {
254 None
255 }
256 }
257}
258
/// Running token totals for a thread, measured against the model's
/// context-window size.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens available (the model's context window).
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, near the limit, or over it.
    ///
    /// In debug builds the warning threshold can be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable.
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            // BUG FIX: a malformed override previously panicked via
            // `.parse().unwrap()`; fall back to the default threshold instead.
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a new usage with `tokens` added to the total; `max` is unchanged.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Coarse classification of how full the context window is.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
299
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    // Last-modified timestamp; refreshed via `touch_updated_at`.
    updated_at: DateTime<Utc>,
    // Short generated title; `None` until a summary has been generated.
    summary: Option<SharedString>,
    // In-flight task generating the short summary.
    pending_summary: Task<Option<()>>,
    // In-flight task generating the detailed summary.
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: Option<CompletionMode>,
    // Messages in order; IDs ascend, which `message()` relies on for binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    // Shared context used to build the system prompt.
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    // Tracks buffers the agent has read/edited.
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint awaiting finalization once generation completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Snapshot of the project taken when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Optional hook observing each request and its raw completion events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Agentic turns remaining before `send_to_model` becomes a no-op.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
337
/// Details recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
345
346impl Thread {
    /// Creates a fresh, empty thread for the given project, using the
    /// registry's default model.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: None,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project's state up front so it reflects how things
            // looked when the thread began.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
397
    /// Reconstructs a thread from its serialized representation.
    ///
    /// Restores messages, tool-use state, and token usage; re-selects the
    /// previously used model, falling back to the registry default when the
    /// saved one is unavailable.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text is persisted; structured
                    // contexts and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
491
    /// Installs a callback invoked with each outgoing request and the raw
    /// completion events received for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
499
    /// Returns this thread's unique ID.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Returns whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// Returns when the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Returns the generated summary, or `None` if not generated yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Returns the shared project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, initializing it from the registry's
    /// default model if none has been selected yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }
534
    /// Returns the currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets the model used for completions in this thread.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Placeholder title used until a summary has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// Returns the summary, or the default placeholder when none exists.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
549
550 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
551 let Some(current_summary) = &self.summary else {
552 // Don't allow setting summary until generated
553 return;
554 };
555
556 let mut new_summary = new_summary.into();
557
558 if new_summary.is_empty() {
559 new_summary = Self::DEFAULT_SUMMARY;
560 }
561
562 if current_summary != &new_summary {
563 self.summary = Some(new_summary);
564 cx.emit(ThreadEvent::SummaryChanged);
565 }
566 }
567
    /// Returns the completion mode override for this thread, if any.
    pub fn completion_mode(&self) -> Option<CompletionMode> {
        self.completion_mode
    }

    /// Sets the completion mode override for this thread.
    pub fn set_completion_mode(&mut self, mode: Option<CompletionMode>) {
        self.completion_mode = mode;
    }

    /// Looks up a message by ID.
    ///
    /// Uses binary search, relying on `messages` being sorted by ascending ID.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }
584
    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Returns whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Returns the working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Iterates over pending tool uses that await user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Returns whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
614
    /// Returns the checkpoint captured for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

    /// Restores the project to the given checkpoint's git state.
    ///
    /// Marks the restore as pending while the git operation runs; on success
    /// the thread is truncated back to the checkpoint's message, on failure
    /// the error is recorded for display.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
654
    /// Promotes the pending checkpoint to a stored one once generation ends,
    /// but only if the project actually changed since it was taken.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep the checkpoint when the project changed; an
                // identical checkpoint offers nothing to restore.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If a final checkpoint can't be taken, keep the pending one
            // rather than losing the restore point.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Returns the state of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
704
    /// Removes the message with the given ID and every message after it,
    /// discarding checkpoints associated with the removed messages.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
718
719 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
720 self.messages
721 .iter()
722 .find(|message| message.id == id)
723 .into_iter()
724 .flat_map(|message| message.loaded_context.contexts.iter())
725 }
726
727 pub fn is_turn_end(&self, ix: usize) -> bool {
728 if self.messages.is_empty() {
729 return false;
730 }
731
732 if !self.is_generating() && ix == self.messages.len() - 1 {
733 return true;
734 }
735
736 let Some(message) = self.messages.get(ix) else {
737 return false;
738 };
739
740 if message.role != Role::Assistant {
741 return false;
742 }
743
744 self.messages
745 .get(ix + 1)
746 .and_then(|message| {
747 self.message(message.id)
748 .map(|next_message| next_message.role == Role::User)
749 })
750 .unwrap_or(false)
751 }
752
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Returns the tool uses attached to the given assistant message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Returns the tool results produced for the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// Looks up a single tool result by tool-use ID.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Returns the raw output content of a finished tool use.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// Returns the UI card associated with a tool result, if one was created.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
785
786 /// Return tools that are both enabled and supported by the model
787 pub fn available_tools(
788 &self,
789 cx: &App,
790 model: Arc<dyn LanguageModel>,
791 ) -> Vec<LanguageModelRequestTool> {
792 if model.supports_tools() {
793 self.tools()
794 .read(cx)
795 .enabled_tools(cx)
796 .into_iter()
797 .filter_map(|tool| {
798 // Skip tools that cannot be supported
799 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
800 Some(LanguageModelRequestTool {
801 name: tool.name(),
802 description: tool.description(),
803 input_schema,
804 })
805 })
806 .collect()
807 } else {
808 Vec::default()
809 }
810 }
811
    /// Appends a user message, tracking buffers referenced by its context and
    /// optionally recording a git checkpoint that can be restored later.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Track buffers referenced by the attached context so later changes
        // to them are observed by the action log.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            cx,
        );

        // The checkpoint stays pending until generation finishes; see
        // `finalize_pending_checkpoint`.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
845
    /// Appends an assistant message with no attached context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
    }

    /// Appends a message with the next sequential ID and notifies observers.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
872
873 pub fn edit_message(
874 &mut self,
875 id: MessageId,
876 new_role: Role,
877 new_segments: Vec<MessageSegment>,
878 cx: &mut Context<Self>,
879 ) -> bool {
880 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
881 return false;
882 };
883 message.role = new_role;
884 message.segments = new_segments;
885 self.touch_updated_at();
886 cx.emit(ThreadEvent::MessageEdited(id));
887 true
888 }
889
890 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
891 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
892 return false;
893 };
894 self.messages.remove(index);
895 self.touch_updated_at();
896 cx.emit(ThreadEvent::MessageDeleted(id));
897 true
898 }
899
900 /// Returns the representation of this [`Thread`] in a textual form.
901 ///
902 /// This is the representation we use when attaching a thread as context to another thread.
903 pub fn text(&self) -> String {
904 let mut text = String::new();
905
906 for message in &self.messages {
907 text.push_str(match message.role {
908 language_model::Role::User => "User:",
909 language_model::Role::Assistant => "Assistant:",
910 language_model::Role::System => "System:",
911 });
912 text.push('\n');
913
914 for segment in &message.segments {
915 match segment {
916 MessageSegment::Text(content) => text.push_str(content),
917 MessageSegment::Thinking { text: content, .. } => {
918 text.push_str(&format!("<think>{}</think>", content))
919 }
920 MessageSegment::RedactedThinking(_) => {}
921 }
922 }
923 text.push('\n');
924 }
925
926 text
927 }
928
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot task to complete before reading state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
            })
        })
    }
999
    /// Returns how many agentic turns may still run.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    /// Sets how many agentic turns may still run.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }

    /// Sends the current conversation to the model, consuming one turn.
    /// Does nothing when no turns remain.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), cx);

        self.stream_completion(request, model, window, cx);
    }
1024
1025 pub fn used_tools_since_last_user_message(&self) -> bool {
1026 for message in self.messages.iter().rev() {
1027 if self.tool_use.message_has_tool_results(message.id) {
1028 return true;
1029 } else if message.role == Role::User {
1030 return false;
1031 }
1032 }
1033
1034 false
1035 }
1036
    /// Builds the [`LanguageModelRequest`] for the next completion:
    /// system prompt, conversation messages with their tool uses/results,
    /// and the tools available to the model.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt requires the project context; when it isn't
        // available an error event is emitted rather than failing the request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context is prepended ahead of the message's own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results follow the assistant message that requested them.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Max mode is only requested when the model supports it.
        request.mode = if model.supports_max_mode() {
            self.completion_mode
        } else {
            None
        };

        request
    }
1154
1155 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1156 let mut request = LanguageModelRequest {
1157 thread_id: None,
1158 prompt_id: None,
1159 mode: None,
1160 messages: vec![],
1161 tools: Vec::new(),
1162 stop: Vec::new(),
1163 temperature: None,
1164 };
1165
1166 for message in &self.messages {
1167 let mut request_message = LanguageModelRequestMessage {
1168 role: message.role,
1169 content: Vec::new(),
1170 cache: false,
1171 };
1172
1173 for segment in &message.segments {
1174 match segment {
1175 MessageSegment::Text(text) => request_message
1176 .content
1177 .push(MessageContent::Text(text.clone())),
1178 MessageSegment::Thinking { .. } => {}
1179 MessageSegment::RedactedThinking(_) => {}
1180 }
1181 }
1182
1183 if request_message.content.is_empty() {
1184 continue;
1185 }
1186
1187 request.messages.push(request_message);
1188 }
1189
1190 request.messages.push(LanguageModelRequestMessage {
1191 role: Role::User,
1192 content: vec![MessageContent::Text(added_user_message)],
1193 cache: false,
1194 });
1195
1196 request
1197 }
1198
1199 fn attached_tracked_files_state(
1200 &self,
1201 messages: &mut Vec<LanguageModelRequestMessage>,
1202 cx: &App,
1203 ) {
1204 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1205
1206 let mut stale_message = String::new();
1207
1208 let action_log = self.action_log.read(cx);
1209
1210 for stale_file in action_log.stale_buffers(cx) {
1211 let Some(file) = stale_file.read(cx).file() else {
1212 continue;
1213 };
1214
1215 if stale_message.is_empty() {
1216 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1217 }
1218
1219 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1220 }
1221
1222 let mut content = Vec::with_capacity(2);
1223
1224 if !stale_message.is_empty() {
1225 content.push(stale_message.into());
1226 }
1227
1228 if !content.is_empty() {
1229 let context_message = LanguageModelRequestMessage {
1230 role: Role::User,
1231 content,
1232 cache: false,
1233 };
1234
1235 messages.push(context_message);
1236 }
1237 }
1238
    /// Streams a completion for `request` from `model`, applying each event to
    /// the thread as it arrives (text/thinking chunks, tool uses, usage
    /// updates), and registers the work as a `PendingCompletion`.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // When a request callback is installed, capture the request and every
        // response event so the callback can be invoked with them at the end.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot usage before streaming so the delta can be reported to
            // telemetry afterwards.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                // Set once this response creates an assistant message, so later
                // tool uses can attach to it.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            // Invalid tool-input JSON is recoverable: record it
                            // as a failed tool use and keep streaming.
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            // Any other completion error aborts the stream.
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Usage updates appear to be cumulative within a
                                // response, so only the delta since the previous
                                // update is added to the running total.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to the current assistant
                                // message, creating an empty one if none exists.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only incomplete inputs are re-emitted below,
                                // for incremental UI updates while streaming.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so other tasks get a chance to run.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            // The model requested tool calls; run them now.
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known error types to their dedicated UI events;
                            // anything else is surfaced as a generic message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        // Remember the overflow so token-usage
                                        // reporting can reflect it.
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report the token usage consumed by this completion.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1537
1538 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1539 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1540 return;
1541 };
1542
1543 if !model.provider.is_authenticated(cx) {
1544 return;
1545 }
1546
1547 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1548 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1549 If the conversation is about a specific subject, include it in the title. \
1550 Be descriptive. DO NOT speak in the first person.";
1551
1552 let request = self.to_summarize_request(added_user_message.into());
1553
1554 self.pending_summary = cx.spawn(async move |this, cx| {
1555 async move {
1556 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1557 let (mut messages, usage) = stream.await?;
1558
1559 if let Some(usage) = usage {
1560 this.update(cx, |_thread, cx| {
1561 cx.emit(ThreadEvent::UsageUpdated(usage));
1562 })
1563 .ok();
1564 }
1565
1566 let mut new_summary = String::new();
1567 while let Some(message) = messages.stream.next().await {
1568 let text = message?;
1569 let mut lines = text.lines();
1570 new_summary.extend(lines.next());
1571
1572 // Stop if the LLM generated multiple lines.
1573 if lines.next().is_some() {
1574 break;
1575 }
1576 }
1577
1578 this.update(cx, |this, cx| {
1579 if !new_summary.is_empty() {
1580 this.summary = Some(new_summary.into());
1581 }
1582
1583 cx.emit(ThreadEvent::SummaryGenerated);
1584 })?;
1585
1586 anyhow::Ok(())
1587 }
1588 .log_err()
1589 .await
1590 });
1591 }
1592
    /// Kicks off generation of a detailed Markdown summary of the thread,
    /// unless one is already generated (or generating) for the current last
    /// message. Progress is published through `detailed_summary_tx`, and the
    /// thread is saved afterwards so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // An empty thread has nothing to summarize.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        // Publish the in-progress state before spawning the work.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed; reset the state so a later call retries.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1682
1683 pub async fn wait_for_detailed_summary_or_text(
1684 this: &Entity<Self>,
1685 cx: &mut AsyncApp,
1686 ) -> Option<SharedString> {
1687 let mut detailed_summary_rx = this
1688 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1689 .ok()?;
1690 loop {
1691 match detailed_summary_rx.recv().await? {
1692 DetailedSummaryState::Generating { .. } => {}
1693 DetailedSummaryState::NotGenerated => {
1694 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1695 }
1696 DetailedSummaryState::Generated { text, .. } => return Some(text),
1697 }
1698 }
1699 }
1700
1701 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1702 self.detailed_summary_rx
1703 .borrow()
1704 .text()
1705 .unwrap_or_else(|| self.text().into())
1706 }
1707
1708 pub fn is_generating_detailed_summary(&self) -> bool {
1709 matches!(
1710 &*self.detailed_summary_rx.borrow(),
1711 DetailedSummaryState::Generating { .. }
1712 )
1713 }
1714
1715 pub fn use_pending_tools(
1716 &mut self,
1717 window: Option<AnyWindowHandle>,
1718 cx: &mut Context<Self>,
1719 model: Arc<dyn LanguageModel>,
1720 ) -> Vec<PendingToolUse> {
1721 self.auto_capture_telemetry(cx);
1722 let request = self.to_completion_request(model, cx);
1723 let messages = Arc::new(request.messages);
1724 let pending_tool_uses = self
1725 .tool_use
1726 .pending_tool_uses()
1727 .into_iter()
1728 .filter(|tool_use| tool_use.status.is_idle())
1729 .cloned()
1730 .collect::<Vec<_>>();
1731
1732 for tool_use in pending_tool_uses.iter() {
1733 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1734 if tool.needs_confirmation(&tool_use.input, cx)
1735 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1736 {
1737 self.tool_use.confirm_tool_use(
1738 tool_use.id.clone(),
1739 tool_use.ui_text.clone(),
1740 tool_use.input.clone(),
1741 messages.clone(),
1742 tool,
1743 );
1744 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1745 } else {
1746 self.run_tool(
1747 tool_use.id.clone(),
1748 tool_use.ui_text.clone(),
1749 tool_use.input.clone(),
1750 &messages,
1751 tool,
1752 window,
1753 cx,
1754 );
1755 }
1756 }
1757 }
1758
1759 pending_tool_uses
1760 }
1761
1762 pub fn receive_invalid_tool_json(
1763 &mut self,
1764 tool_use_id: LanguageModelToolUseId,
1765 tool_name: Arc<str>,
1766 invalid_json: Arc<str>,
1767 error: String,
1768 window: Option<AnyWindowHandle>,
1769 cx: &mut Context<Thread>,
1770 ) {
1771 log::error!("The model returned invalid input JSON: {invalid_json}");
1772
1773 let pending_tool_use = self.tool_use.insert_tool_output(
1774 tool_use_id.clone(),
1775 tool_name,
1776 Err(anyhow!("Error parsing input JSON: {error}")),
1777 self.configured_model.as_ref(),
1778 );
1779 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1780 pending_tool_use.ui_text.clone()
1781 } else {
1782 log::error!(
1783 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1784 );
1785 format!("Unknown tool {}", tool_use_id).into()
1786 };
1787
1788 cx.emit(ThreadEvent::InvalidToolInput {
1789 tool_use_id: tool_use_id.clone(),
1790 ui_text,
1791 invalid_input_json: invalid_json,
1792 });
1793
1794 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1795 }
1796
1797 pub fn run_tool(
1798 &mut self,
1799 tool_use_id: LanguageModelToolUseId,
1800 ui_text: impl Into<SharedString>,
1801 input: serde_json::Value,
1802 messages: &[LanguageModelRequestMessage],
1803 tool: Arc<dyn Tool>,
1804 window: Option<AnyWindowHandle>,
1805 cx: &mut Context<Thread>,
1806 ) {
1807 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1808 self.tool_use
1809 .run_pending_tool(tool_use_id, ui_text.into(), task);
1810 }
1811
1812 fn spawn_tool_use(
1813 &mut self,
1814 tool_use_id: LanguageModelToolUseId,
1815 messages: &[LanguageModelRequestMessage],
1816 input: serde_json::Value,
1817 tool: Arc<dyn Tool>,
1818 window: Option<AnyWindowHandle>,
1819 cx: &mut Context<Thread>,
1820 ) -> Task<()> {
1821 let tool_name: Arc<str> = tool.name().into();
1822
1823 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1824 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1825 } else {
1826 tool.run(
1827 input,
1828 messages,
1829 self.project.clone(),
1830 self.action_log.clone(),
1831 window,
1832 cx,
1833 )
1834 };
1835
1836 // Store the card separately if it exists
1837 if let Some(card) = tool_result.card.clone() {
1838 self.tool_use
1839 .insert_tool_result_card(tool_use_id.clone(), card);
1840 }
1841
1842 cx.spawn({
1843 async move |thread: WeakEntity<Thread>, cx| {
1844 let output = tool_result.output.await;
1845
1846 thread
1847 .update(cx, |thread, cx| {
1848 let pending_tool_use = thread.tool_use.insert_tool_output(
1849 tool_use_id.clone(),
1850 tool_name,
1851 output,
1852 thread.configured_model.as_ref(),
1853 );
1854 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1855 })
1856 .ok();
1857 }
1858 })
1859 }
1860
1861 fn tool_finished(
1862 &mut self,
1863 tool_use_id: LanguageModelToolUseId,
1864 pending_tool_use: Option<PendingToolUse>,
1865 canceled: bool,
1866 window: Option<AnyWindowHandle>,
1867 cx: &mut Context<Self>,
1868 ) {
1869 if self.all_tools_finished() {
1870 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
1871 if !canceled {
1872 self.send_to_model(model.clone(), window, cx);
1873 }
1874 self.auto_capture_telemetry(cx);
1875 }
1876 }
1877
1878 cx.emit(ThreadEvent::ToolFinished {
1879 tool_use_id,
1880 pending_tool_use,
1881 });
1882 }
1883
1884 /// Cancels the last pending completion, if there are any pending.
1885 ///
1886 /// Returns whether a completion was canceled.
1887 pub fn cancel_last_completion(
1888 &mut self,
1889 window: Option<AnyWindowHandle>,
1890 cx: &mut Context<Self>,
1891 ) -> bool {
1892 let mut canceled = self.pending_completions.pop().is_some();
1893
1894 for pending_tool_use in self.tool_use.cancel_pending() {
1895 canceled = true;
1896 self.tool_finished(
1897 pending_tool_use.id.clone(),
1898 Some(pending_tool_use),
1899 true,
1900 window,
1901 cx,
1902 );
1903 }
1904
1905 self.finalize_pending_checkpoint(cx);
1906 canceled
1907 }
1908
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It mutates no thread state
    /// itself; it only emits [`ThreadEvent::CancelEditing`].
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
1916
    /// Returns the thread-level feedback rating, if one has been reported.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1920
    /// Returns the feedback rating recorded for `message_id`, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1924
    /// Records `feedback` for a single message and reports it to telemetry,
    /// along with the serialized thread and a project snapshot.
    ///
    /// No-ops when the same feedback was already recorded for this message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything the telemetry event needs while we still have
        // app context; the rest runs on the background executor.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failures degrade to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            // Send buffered telemetry now.
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
1982
    /// Records thread-level `feedback` and reports it to telemetry.
    ///
    /// When the thread contains an assistant message, the feedback is attached
    /// to the most recent one via [`Self::report_message_feedback`]; otherwise
    /// a thread-level event is emitted directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failures degrade to a null payload rather than
                // dropping the event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                // Send buffered telemetry now.
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2028
2029 /// Create a snapshot of the current project state including git information and unsaved buffers.
2030 fn project_snapshot(
2031 project: Entity<Project>,
2032 cx: &mut Context<Self>,
2033 ) -> Task<Arc<ProjectSnapshot>> {
2034 let git_store = project.read(cx).git_store().clone();
2035 let worktree_snapshots: Vec<_> = project
2036 .read(cx)
2037 .visible_worktrees(cx)
2038 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2039 .collect();
2040
2041 cx.spawn(async move |_, cx| {
2042 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2043
2044 let mut unsaved_buffers = Vec::new();
2045 cx.update(|app_cx| {
2046 let buffer_store = project.read(app_cx).buffer_store();
2047 for buffer_handle in buffer_store.read(app_cx).buffers() {
2048 let buffer = buffer_handle.read(app_cx);
2049 if buffer.is_dirty() {
2050 if let Some(file) = buffer.file() {
2051 let path = file.path().to_string_lossy().to_string();
2052 unsaved_buffers.push(path);
2053 }
2054 }
2055 }
2056 })
2057 .ok();
2058
2059 Arc::new(ProjectSnapshot {
2060 worktree_snapshots,
2061 unsaved_buffer_paths: unsaved_buffers,
2062 timestamp: Utc::now(),
2063 })
2064 })
2065 }
2066
    /// Captures the absolute path and git state (remote URL, HEAD SHA, branch,
    /// diff) of a single worktree for inclusion in a project snapshot.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose root contains this worktree, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name.to_string());
                        repo.send_job(None, |state, _| async move {
                            // Remote/HEAD/diff details require a local backend;
                            // otherwise only the branch name is reported.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option/Result/future layers; any failure along
            // the way yields no git state rather than an error.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2144
2145 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2146 let mut markdown = Vec::new();
2147
2148 if let Some(summary) = self.summary() {
2149 writeln!(markdown, "# {summary}\n")?;
2150 };
2151
2152 for message in self.messages() {
2153 writeln!(
2154 markdown,
2155 "## {role}\n",
2156 role = match message.role {
2157 Role::User => "User",
2158 Role::Assistant => "Assistant",
2159 Role::System => "System",
2160 }
2161 )?;
2162
2163 if !message.loaded_context.text.is_empty() {
2164 writeln!(markdown, "{}", message.loaded_context.text)?;
2165 }
2166
2167 if !message.loaded_context.images.is_empty() {
2168 writeln!(
2169 markdown,
2170 "\n{} images attached as context.\n",
2171 message.loaded_context.images.len()
2172 )?;
2173 }
2174
2175 for segment in &message.segments {
2176 match segment {
2177 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2178 MessageSegment::Thinking { text, .. } => {
2179 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2180 }
2181 MessageSegment::RedactedThinking(_) => {}
2182 }
2183 }
2184
2185 for tool_use in self.tool_uses_for_message(message.id, cx) {
2186 writeln!(
2187 markdown,
2188 "**Use Tool: {} ({})**",
2189 tool_use.name, tool_use.id
2190 )?;
2191 writeln!(markdown, "```json")?;
2192 writeln!(
2193 markdown,
2194 "{}",
2195 serde_json::to_string_pretty(&tool_use.input)?
2196 )?;
2197 writeln!(markdown, "```")?;
2198 }
2199
2200 for tool_result in self.tool_results_for_message(message.id) {
2201 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2202 if tool_result.is_error {
2203 write!(markdown, " (Error)")?;
2204 }
2205
2206 writeln!(markdown, "**\n")?;
2207 writeln!(markdown, "{}", tool_result.content)?;
2208 }
2209 }
2210
2211 Ok(String::from_utf8_lossy(&markdown).to_string())
2212 }
2213
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept.
    ///
    /// Forwards to [`ActionLog::keep_edits_in_range`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2224
    /// Marks all of the agent's edits as kept.
    ///
    /// Forwards to [`ActionLog::keep_all_edits`].
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2229
    /// Rejects (reverts) the agent's edits within the given ranges of `buffer`.
    ///
    /// Forwards to [`ActionLog::reject_edits_in_ranges`] and returns its task.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2240
    /// Returns this thread's [`ActionLog`] entity.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2244
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2248
2249 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2250 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2251 return;
2252 }
2253
2254 let now = Instant::now();
2255 if let Some(last) = self.last_auto_capture_at {
2256 if now.duration_since(last).as_secs() < 10 {
2257 return;
2258 }
2259 }
2260
2261 self.last_auto_capture_at = Some(now);
2262
2263 let thread_id = self.id().clone();
2264 let github_login = self
2265 .project
2266 .read(cx)
2267 .user_store()
2268 .read(cx)
2269 .current_user()
2270 .map(|user| user.github_login.clone());
2271 let client = self.project.read(cx).client().clone();
2272 let serialize_task = self.serialize(cx);
2273
2274 cx.background_executor()
2275 .spawn(async move {
2276 if let Ok(serialized_thread) = serialize_task.await {
2277 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2278 telemetry::event!(
2279 "Agent Thread Auto-Captured",
2280 thread_id = thread_id.to_string(),
2281 thread_data = thread_data,
2282 auto_capture_reason = "tracked_user",
2283 github_login = github_login
2284 );
2285
2286 client.telemetry().flush_events().await;
2287 }
2288 }
2289 })
2290 .detach();
2291 }
2292
    /// Total token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2296
2297 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2298 let Some(model) = self.configured_model.as_ref() else {
2299 return TotalTokenUsage::default();
2300 };
2301
2302 let max = model.model.max_token_count();
2303
2304 let index = self
2305 .messages
2306 .iter()
2307 .position(|msg| msg.id == message_id)
2308 .unwrap_or(0);
2309
2310 if index == 0 {
2311 return TotalTokenUsage { total: 0, max };
2312 }
2313
2314 let token_usage = &self
2315 .request_token_usage
2316 .get(index - 1)
2317 .cloned()
2318 .unwrap_or_default();
2319
2320 TotalTokenUsage {
2321 total: token_usage.total_tokens() as usize,
2322 max,
2323 }
2324 }
2325
2326 pub fn total_token_usage(&self) -> TotalTokenUsage {
2327 let Some(model) = self.configured_model.as_ref() else {
2328 return TotalTokenUsage::default();
2329 };
2330
2331 let max = model.model.max_token_count();
2332
2333 if let Some(exceeded_error) = &self.exceeded_window_error {
2334 if model.model.id() == exceeded_error.model_id {
2335 return TotalTokenUsage {
2336 total: exceeded_error.token_count,
2337 max,
2338 };
2339 }
2340 }
2341
2342 let total = self
2343 .token_usage_at_last_message()
2344 .unwrap_or_default()
2345 .total_tokens() as usize;
2346
2347 TotalTokenUsage { total, max }
2348 }
2349
2350 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2351 self.request_token_usage
2352 .get(self.messages.len().saturating_sub(1))
2353 .or_else(|| self.request_token_usage.last())
2354 .cloned()
2355 }
2356
2357 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2358 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2359 self.request_token_usage
2360 .resize(self.messages.len(), placeholder);
2361
2362 if let Some(last) = self.request_token_usage.last_mut() {
2363 *last = token_usage;
2364 }
2365 }
2366
2367 pub fn deny_tool_use(
2368 &mut self,
2369 tool_use_id: LanguageModelToolUseId,
2370 tool_name: Arc<str>,
2371 window: Option<AnyWindowHandle>,
2372 cx: &mut Context<Self>,
2373 ) {
2374 let err = Err(anyhow::anyhow!(
2375 "Permission to run tool action denied by user"
2376 ));
2377
2378 self.tool_use.insert_tool_output(
2379 tool_use_id.clone(),
2380 tool_name,
2381 err,
2382 self.configured_model.as_ref(),
2383 );
2384 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2385 }
2386}
2387
/// Errors surfaced to the user while a thread is running a completion.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    // Displayed as "Payment required".
    #[error("Payment required")]
    PaymentRequired,
    // Displayed as "Max monthly spend reached".
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    // Request quota exhausted for the given plan.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a header and message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2402
/// Events emitted by a [`Thread`] as completions stream in, tools run, and
/// messages change. Observers subscribe via gpui's `EventEmitter` machinery.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown to the user.
    ShowError(ThreadError),
    /// New request-usage information was received.
    UsageUpdated(RequestUsage),
    StreamedCompletion,
    ReceivedTextChunk,
    /// A chunk of assistant text streamed into the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant "thinking" text streamed into the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool-use request streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that failed to parse as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    CancelEditing,
}
2440
// Marks `Thread` as an emitter of `ThreadEvent`, enabling subscriptions.
impl EventEmitter<ThreadEvent> for Thread {}
2442
/// An in-flight completion request, identified by `id`.
struct PendingCompletion {
    id: usize,
    // Held only to keep the spawned completion task alive.
    _task: Task<()>,
}
2447
// Tests for thread/message/context interaction: context inclusion, context
// de-duplication across messages, and stale-buffer notifications.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use assistant_tool::ToolRegistry;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use language_model::fake_provider::FakeLanguageModel;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // Verifies that an attached file is rendered into <context>/<files> tags
    // and prepended to the user's message text in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", loaded_context, None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
 println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    // Verifies that each message only carries context that was not already
    // attached to an earlier message in the thread.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    // Verifies that messages without attached context have empty context text
    // and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    // Verifies that modifying a context buffer after it was attached causes a
    // stale-buffer notification to be appended to the completion request.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    // Registers all settings globals the thread machinery reads during tests.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    // Builds the common test fixtures: a workspace, a thread store with an
    // empty tool set, a fresh thread, a context store, and a fake model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }

    // Opens `path` as a buffer and attaches it to the context store,
    // returning the buffer so tests can mutate it.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
}