1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
18 WeakEntity,
19};
20use language_model::{
21 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
22 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
23 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
24 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
25 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
26 TokenUsage,
27};
28use postage::stream::Stream as _;
29use project::Project;
30use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
31use prompt_store::{ModelContext, PromptBuilder};
32use proto::Plan;
33use schemars::JsonSchema;
34use serde::{Deserialize, Serialize};
35use settings::Settings;
36use thiserror::Error;
37use util::{ResultExt as _, TryFutureExt as _, post_inc};
38use uuid::Uuid;
39use zed_llm_client::CompletionMode;
40
41use crate::ThreadStore;
42use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
43use crate::thread_store::{
44 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
45 SerializedToolUse, SharedProjectContext,
46};
47use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
48
49#[derive(
50 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
51)]
52pub struct ThreadId(Arc<str>);
53
54impl ThreadId {
55 pub fn new() -> Self {
56 Self(Uuid::new_v4().to_string().into())
57 }
58}
59
60impl std::fmt::Display for ThreadId {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(f, "{}", self.0)
63 }
64}
65
66impl From<&str> for ThreadId {
67 fn from(value: &str) -> Self {
68 Self(value.into())
69 }
70}
71
72/// The ID of the user prompt that initiated a request.
73///
74/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
75#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
76pub struct PromptId(Arc<str>);
77
78impl PromptId {
79 pub fn new() -> Self {
80 Self(Uuid::new_v4().to_string().into())
81 }
82}
83
84impl std::fmt::Display for PromptId {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 write!(f, "{}", self.0)
87 }
88}
89
90#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
91pub struct MessageId(pub(crate) usize);
92
93impl MessageId {
94 fn post_inc(&mut self) -> Self {
95 Self(post_inc(&mut self.0))
96 }
97}
98
99/// A message in a [`Thread`].
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Monotonically increasing, thread-local identifier.
    pub id: MessageId,
    /// Who authored the message (user, assistant, or system).
    pub role: Role,
    /// Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context that was loaded and attached when this message was created.
    pub loaded_context: LoadedContext,
}
107
108impl Message {
109 /// Returns whether the message contains any meaningful text that should be displayed
110 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
111 pub fn should_display_content(&self) -> bool {
112 self.segments.iter().all(|segment| segment.should_display())
113 }
114
115 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
116 if let Some(MessageSegment::Thinking {
117 text: segment,
118 signature: current_signature,
119 }) = self.segments.last_mut()
120 {
121 if let Some(signature) = signature {
122 *current_signature = Some(signature);
123 }
124 segment.push_str(text);
125 } else {
126 self.segments.push(MessageSegment::Thinking {
127 text: text.to_string(),
128 signature,
129 });
130 }
131 }
132
133 pub fn push_text(&mut self, text: &str) {
134 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
135 segment.push_str(text);
136 } else {
137 self.segments.push(MessageSegment::Text(text.to_string()));
138 }
139 }
140
141 pub fn to_string(&self) -> String {
142 let mut result = String::new();
143
144 if !self.loaded_context.text.is_empty() {
145 result.push_str(&self.loaded_context.text);
146 }
147
148 for segment in &self.segments {
149 match segment {
150 MessageSegment::Text(text) => result.push_str(text),
151 MessageSegment::Thinking { text, .. } => {
152 result.push_str("<think>\n");
153 result.push_str(text);
154 result.push_str("\n</think>");
155 }
156 MessageSegment::RedactedThinking(_) => {}
157 }
158 }
159
160 result
161 }
162}
163
/// One piece of a message's content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary visible text.
    Text(String),
    /// Model "thinking" text, optionally accompanied by a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque, provider-redacted thinking payload; never user-visible.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Whether this segment carries meaningful, displayable content.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking is never displayed. (Previously this returned
    /// `text.is_empty()`, inverting the documented intent of
    /// [`Message::should_display_content`].)
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
183
/// Point-in-time snapshot of the project, captured for storage/telemetry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One snapshot per worktree in the project.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved edits at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken (UTC).
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Absolute path of the worktree root, stringified.
    pub worktree_path: String,
    /// `None` when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}

/// Git repository state captured alongside a worktree snapshot.
/// Each field is independently optional; absence means it couldn't be determined.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Textual diff of uncommitted changes, when one was computed.
    pub diff: Option<String>,
}
204
/// Associates a git checkpoint with the message that triggered it, so the
/// project can be rolled back to its state as of that message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up/down) on a thread or individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
216
217pub enum LastRestoreCheckpoint {
218 Pending {
219 message_id: MessageId,
220 },
221 Error {
222 message_id: MessageId,
223 error: String,
224 },
225}
226
227impl LastRestoreCheckpoint {
228 pub fn message_id(&self) -> MessageId {
229 match self {
230 LastRestoreCheckpoint::Pending { message_id } => *message_id,
231 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
232 }
233 }
234}
235
236#[derive(Clone, Debug, Default, Serialize, Deserialize)]
237pub enum DetailedSummaryState {
238 #[default]
239 NotGenerated,
240 Generating {
241 message_id: MessageId,
242 },
243 Generated {
244 text: SharedString,
245 message_id: MessageId,
246 },
247}
248
249impl DetailedSummaryState {
250 fn text(&self) -> Option<SharedString> {
251 if let Self::Generated { text, .. } = self {
252 Some(text.clone())
253 } else {
254 None
255 }
256 }
257}
258
/// Running token count for a thread, measured against a model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens the context window allows.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, near the limit, or over it.
    ///
    /// Note: when `max == 0` (e.g. a default-constructed value), `total >= max`
    /// always holds, so the result is `Exceeded`.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be tuned via an env var.
        // Fall back to the default instead of panicking on a malformed value,
        // since the env var is user-controlled input.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a new usage with `tokens` added to the total; `max` is unchanged.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Coarse classification of context-window consumption.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    /// At or past the warning threshold (80% by default), but not full.
    Warning,
    /// The context window is full or overfull.
    Exceeded,
}
299
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Short title; `None` until one has been generated.
    summary: Option<SharedString>,
    /// In-flight short-summary generation, if any.
    pending_summary: Task<Option<()>>,
    /// In-flight detailed-summary generation, if any.
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: Option<CompletionMode>,
    /// Messages ordered by ascending `MessageId` (binary-searched in `message`).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    /// Counter used to assign IDs to pending completions.
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint awaiting finalization once the current generation completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history (persisted with the thread).
    request_token_usage: Vec<TokenUsage>,
    /// Total tokens used across all requests in this thread.
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    /// Throttle marker for automatic telemetry capture.
    last_auto_capture_at: Option<Instant>,
    /// Test/debug hook invoked with each request and its completion events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Remaining agentic turns; `u32::MAX` means effectively unlimited.
    remaining_turns: u32,
}
336
/// Records that a request overflowed the model's context window, so the UI
/// can surface it and avoid retrying with the same model unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
344
345impl Thread {
    /// Creates a new, empty thread with a fresh ID.
    ///
    /// Spawns a background task that captures an initial project snapshot;
    /// everything else starts at its default/empty state.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: None,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project state as of thread creation; shared so
                // multiple consumers (e.g. serialization) can await it.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
393
    /// Reconstructs a thread from its serialized form (e.g. loaded from disk).
    ///
    /// Transient state (checkpoints, feedback, pending completions, window
    /// errors) is reset; tool state is rebuilt from the serialized messages.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Resume message IDs after the highest serialized one.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: None,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // the structured contexts and images are not persisted.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
        }
    }
473
    /// Installs a hook that observes every request and its completion events
    /// (used for testing/debugging).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// True when the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the last-modified timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt ID; call when the user submits a new message.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The thread summary, or `None` if one hasn't been generated yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Placeholder title used until a summary is generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
515
516 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
517 let Some(current_summary) = &self.summary else {
518 // Don't allow setting summary until generated
519 return;
520 };
521
522 let mut new_summary = new_summary.into();
523
524 if new_summary.is_empty() {
525 new_summary = Self::DEFAULT_SUMMARY;
526 }
527
528 if current_summary != &new_summary {
529 self.summary = Some(new_summary);
530 cx.emit(ThreadEvent::SummaryChanged);
531 }
532 }
533
    pub fn completion_mode(&self) -> Option<CompletionMode> {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: Option<CompletionMode>) {
        self.completion_mode = mode;
    }

    /// Looks up a message by ID.
    ///
    /// Uses binary search, which relies on `messages` being sorted by
    /// ascending ID (IDs are assigned monotonically in `insert_message`).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// True while a completion is streaming or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The pending tool use with the given ID, if any.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses waiting on user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
584
    /// Restores the project to the given checkpoint's git state and truncates
    /// the thread back to the checkpoint's message.
    ///
    /// Progress is tracked in `last_restore_checkpoint` (Pending, then cleared
    /// on success or set to Error on failure); emits `CheckpointChanged` at
    /// each transition.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around so the UI can surface it.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Only truncate the conversation once the git restore succeeded.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
620
    /// Once generation has finished, decides whether the pending checkpoint is
    /// worth keeping: it is inserted only if the project actually changed since
    /// it was taken (or if the comparison itself failed, erring on the side of
    /// keeping it).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still generating — leave the checkpoint pending for now.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only record the checkpoint if something actually changed.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint to compare against; keep the
            // pending one rather than silently dropping it.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
659
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Status of the most recent checkpoint-restore attempt, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
670
    /// Removes the message with `message_id` and every message after it,
    /// dropping the checkpoints associated with the removed messages.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
684
    /// Iterates over the contexts attached to the message with `id`
    /// (empty if the message doesn't exist or has no attached contexts).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
692
693 pub fn is_turn_end(&self, ix: usize) -> bool {
694 if self.messages.is_empty() {
695 return false;
696 }
697
698 if !self.is_generating() && ix == self.messages.len() - 1 {
699 return true;
700 }
701
702 let Some(message) = self.messages.get(ix) else {
703 return false;
704 };
705
706 if message.role != Role::Assistant {
707 return false;
708 }
709
710 self.messages
711 .get(ix + 1)
712 .and_then(|message| {
713 self.message(message.id)
714 .map(|next_message| next_message.role == Role::User)
715 })
716 .unwrap_or(false)
717 }
718
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses attached to the given assistant message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced for the given assistant message's tool uses.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The textual output of a completed tool use, if a result exists.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool result, if one was registered.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
751
752 /// Return tools that are both enabled and supported by the model
753 pub fn available_tools(
754 &self,
755 cx: &App,
756 model: Arc<dyn LanguageModel>,
757 ) -> Vec<LanguageModelRequestTool> {
758 if model.supports_tools() {
759 self.tools()
760 .read(cx)
761 .enabled_tools(cx)
762 .into_iter()
763 .filter_map(|tool| {
764 // Skip tools that cannot be supported
765 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
766 Some(LanguageModelRequestTool {
767 name: tool.name(),
768 description: tool.description(),
769 input_schema,
770 })
771 })
772 .collect()
773 } else {
774 Vec::default()
775 }
776 }
777
    /// Appends a user message with its loaded context, optionally recording a
    /// git checkpoint to finalize once the resulting generation completes.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Track buffers referenced by the attached context so stale reads can
        // be detected later.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            cx,
        );

        // Held as pending until generation finishes; see
        // `finalize_pending_checkpoint`.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Appends an assistant message with no attached context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
    }
819
    /// Appends a message with the next sequential ID, bumps the thread's
    /// timestamp, and emits `MessageAdded`.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
838
839 pub fn edit_message(
840 &mut self,
841 id: MessageId,
842 new_role: Role,
843 new_segments: Vec<MessageSegment>,
844 cx: &mut Context<Self>,
845 ) -> bool {
846 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
847 return false;
848 };
849 message.role = new_role;
850 message.segments = new_segments;
851 self.touch_updated_at();
852 cx.emit(ThreadEvent::MessageEdited(id));
853 true
854 }
855
856 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
857 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
858 return false;
859 };
860 self.messages.remove(index);
861 self.touch_updated_at();
862 cx.emit(ThreadEvent::MessageDeleted(id));
863 true
864 }
865
866 /// Returns the representation of this [`Thread`] in a textual form.
867 ///
868 /// This is the representation we use when attaching a thread as context to another thread.
869 pub fn text(&self) -> String {
870 let mut text = String::new();
871
872 for message in &self.messages {
873 text.push_str(match message.role {
874 language_model::Role::User => "User:",
875 language_model::Role::Assistant => "Assistant:",
876 language_model::Role::System => "System:",
877 });
878 text.push('\n');
879
880 for segment in &message.segments {
881 match segment {
882 MessageSegment::Text(content) => text.push_str(content),
883 MessageSegment::Thinking { text: content, .. } => {
884 text.push_str(&format!("<think>{}</think>", content))
885 }
886 MessageSegment::RedactedThinking(_) => {}
887 }
888 }
889 text.push('\n');
890 }
891
892 text
893 }
894
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Await the initial snapshot first so it can be stored alongside the
        // thread.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted; the
                        // structured contexts/images are rebuilt as empty on
                        // deserialization.
                        context: message.loaded_context.text.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
958
    /// Agentic turns remaining before the thread stops auto-continuing.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }

    /// Builds a completion request from the current thread state and streams
    /// it to `model`, consuming one turn. No-op when no turns remain.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), cx);

        self.stream_completion(request, model, window, cx);
    }
983
984 pub fn used_tools_since_last_user_message(&self) -> bool {
985 for message in self.messages.iter().rev() {
986 if self.tool_use.message_has_tool_results(message.id) {
987 return true;
988 } else if message.role == Role::User {
989 return false;
990 }
991 }
992
993 false
994 }
995
    /// Assembles the full completion request for `model`: system prompt,
    /// conversation history (with tool uses/results interleaved), stale-file
    /// notices, available tools, and completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // System prompt goes first; failures are surfaced to the user but the
        // request is still built without it.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context text/images precede the message's own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results follow the assistant message that requested them.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            self.completion_mode
        } else {
            None
        };

        request
    }
1113
    /// Builds a minimal request for summarization: only the plain-text
    /// segments of the conversation (thinking and tool traffic omitted),
    /// followed by `added_user_message` asking for the summary.
    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => request_message
                        .content
                        .push(MessageContent::Text(text.clone())),
                    // Thinking content is irrelevant for summarization.
                    MessageSegment::Thinking { .. } => {}
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            // Skip messages that contributed no text at all.
            if request_message.content.is_empty() {
                continue;
            }

            request.messages.push(request_message);
        }

        request.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text(added_user_message)],
            cache: false,
        });

        request
    }
1157
1158 fn attached_tracked_files_state(
1159 &self,
1160 messages: &mut Vec<LanguageModelRequestMessage>,
1161 cx: &App,
1162 ) {
1163 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1164
1165 let mut stale_message = String::new();
1166
1167 let action_log = self.action_log.read(cx);
1168
1169 for stale_file in action_log.stale_buffers(cx) {
1170 let Some(file) = stale_file.read(cx).file() else {
1171 continue;
1172 };
1173
1174 if stale_message.is_empty() {
1175 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1176 }
1177
1178 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1179 }
1180
1181 let mut content = Vec::with_capacity(2);
1182
1183 if !stale_message.is_empty() {
1184 content.push(stale_message.into());
1185 }
1186
1187 if !content.is_empty() {
1188 let context_message = LanguageModelRequestMessage {
1189 role: Role::User,
1190 content,
1191 cache: false,
1192 };
1193
1194 messages.push(context_message);
1195 }
1196 }
1197
    /// Starts streaming a completion for `request` from `model`, registering
    /// it as a pending completion on this thread.
    ///
    /// The spawned task incrementally applies streamed events (text, thinking,
    /// tool uses, usage updates) to the thread's messages, then on completion
    /// kicks off pending tools (on `StopReason::ToolUse`), maybe triggers
    /// summarization, reports errors, and emits telemetry.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Capture the request (and later each response event) only when a
        // debugging/inspection callback has been installed on this thread.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot cumulative usage before streaming so this completion's
            // token delta can be reported in telemetry afterwards.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                // Surface server-reported usage to the UI as soon as it's known.
                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                // The assistant message that streamed chunks are appended to,
                // created lazily on the first event that needs one.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                // Malformed tool-input JSON is recoverable:
                                // record it as a failed tool call and keep
                                // consuming the rest of the stream.
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Usage updates are cumulative per completion, so add
                                // only the delta since the previous update.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Incomplete input means the tool call is still
                                // streaming; remember the partial JSON to emit below.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so the UI can process the event we just applied.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known error types onto specific UI errors;
                            // anything else is shown as a generic message chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        // Remember the overflow so token-usage UI can
                                        // show it until the model changes.
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report only the tokens consumed by this completion.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1496
    /// Generates a short title for the thread using the configured
    /// thread-summary model, storing the result in `self.summary` and
    /// emitting `SummaryGenerated`. No-op when no summary model is
    /// configured or its provider is unauthenticated.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        // Assigning to `pending_summary` drops (cancels) any previous summarization task.
        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                // Surface server-reported usage to the UI as soon as it's known.
                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    // Titles are single-line; only the first line is kept.
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    // Emitted even when empty so listeners can stop showing a spinner.
                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1551
    /// Kicks off generation of a detailed Markdown summary of the thread,
    /// unless one is already generated (or generating) for the current last
    /// message. Progress and results are published through the
    /// `detailed_summary_tx`/`detailed_summary_rx` watch channel.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Reset the channel state so a later call can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1641
    /// Waits on the detailed-summary watch channel until generation settles,
    /// returning the generated summary, or the thread's plain text if no
    /// summary was (or will be) generated. Returns `None` if the thread
    /// entity is dropped or the channel closes.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still in flight; keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1659
1660 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1661 self.detailed_summary_rx
1662 .borrow()
1663 .text()
1664 .unwrap_or_else(|| self.text().into())
1665 }
1666
    /// Whether a detailed-summary generation task is currently in flight.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1673
    /// Dispatches all idle pending tool uses: tools that require user
    /// confirmation (and aren't globally auto-allowed) are queued for
    /// confirmation, everything else runs immediately. Returns the set of
    /// pending tool uses that were considered.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Tools receive the full request message history as context.
        let request = self.to_completion_request(model, cx);
        let messages = Arc::new(request.messages);
        // Clone the idle tool uses up front so we can mutate `self` while iterating.
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            // Tool uses whose tool is unknown are silently left pending.
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        window,
                        cx,
                    );
                }
            }
        }

        pending_tool_uses
    }
1720
    /// Handles a tool call whose input JSON failed to parse: records a
    /// failed tool result (so the model sees the error next turn), emits an
    /// `InvalidToolInput` event for the UI, and finalizes the tool use.
    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            cx,
        );
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            // Shouldn't happen: a result arrived for a tool use we never tracked.
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        // `canceled = false`: the turn may auto-resume once all tools settle.
        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
1755
1756 pub fn run_tool(
1757 &mut self,
1758 tool_use_id: LanguageModelToolUseId,
1759 ui_text: impl Into<SharedString>,
1760 input: serde_json::Value,
1761 messages: &[LanguageModelRequestMessage],
1762 tool: Arc<dyn Tool>,
1763 window: Option<AnyWindowHandle>,
1764 cx: &mut Context<Thread>,
1765 ) {
1766 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1767 self.tool_use
1768 .run_pending_tool(tool_use_id, ui_text.into(), task);
1769 }
1770
    /// Runs `tool` (unless it's disabled, in which case an error result is
    /// produced immediately), stores any UI card it returns, and spawns a
    /// task that records the tool's output and finalizes the tool use.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Disabled tools still get a result so the model sees why the call failed.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            cx,
                        );
                        // `canceled = false`: normal completion may auto-resume the turn.
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
1819
    /// Called whenever a tool use settles (success, failure, or cancel).
    /// When it was the last outstanding tool and nothing was canceled, the
    /// thread automatically resumes the conversation with the default model.
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            let model_registry = LanguageModelRegistry::read_global(cx);
            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
                // Don't auto-resume a turn the user (or an error) canceled.
                if !canceled {
                    self.send_to_model(model, window, cx);
                }
                self.auto_capture_telemetry(cx);
            }
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }
1843
1844 /// Cancels the last pending completion, if there are any pending.
1845 ///
1846 /// Returns whether a completion was canceled.
1847 pub fn cancel_last_completion(
1848 &mut self,
1849 window: Option<AnyWindowHandle>,
1850 cx: &mut Context<Self>,
1851 ) -> bool {
1852 let mut canceled = self.pending_completions.pop().is_some();
1853
1854 for pending_tool_use in self.tool_use.cancel_pending() {
1855 canceled = true;
1856 self.tool_finished(
1857 pending_tool_use.id.clone(),
1858 Some(pending_tool_use),
1859 true,
1860 window,
1861 cx,
1862 );
1863 }
1864
1865 self.finalize_pending_checkpoint(cx);
1866 canceled
1867 }
1868
    /// The thread-level feedback rating, if the user has given one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1872
    /// The feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1876
    /// Records the user's feedback for one message and reports it (together
    /// with a serialized thread and project snapshot) via telemetry.
    /// No-op when the same rating was already recorded for that message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Gather everything needing `cx` up front; the rest runs off the main thread.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure shouldn't prevent the rating itself from being sent.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            // Flush immediately so the rating isn't lost if the app exits.
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
1934
    /// Records thread-level feedback. When the thread has an assistant
    /// message, the rating is attributed to the latest one (delegating to
    /// `report_message_feedback`); otherwise a thread-only telemetry event
    /// is sent, mirroring that method minus the per-message fields.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // A serialization failure shouldn't prevent the rating from being sent.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
1980
1981 /// Create a snapshot of the current project state including git information and unsaved buffers.
1982 fn project_snapshot(
1983 project: Entity<Project>,
1984 cx: &mut Context<Self>,
1985 ) -> Task<Arc<ProjectSnapshot>> {
1986 let git_store = project.read(cx).git_store().clone();
1987 let worktree_snapshots: Vec<_> = project
1988 .read(cx)
1989 .visible_worktrees(cx)
1990 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1991 .collect();
1992
1993 cx.spawn(async move |_, cx| {
1994 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1995
1996 let mut unsaved_buffers = Vec::new();
1997 cx.update(|app_cx| {
1998 let buffer_store = project.read(app_cx).buffer_store();
1999 for buffer_handle in buffer_store.read(app_cx).buffers() {
2000 let buffer = buffer_handle.read(app_cx);
2001 if buffer.is_dirty() {
2002 if let Some(file) = buffer.file() {
2003 let path = file.path().to_string_lossy().to_string();
2004 unsaved_buffers.push(path);
2005 }
2006 }
2007 }
2008 })
2009 .ok();
2010
2011 Arc::new(ProjectSnapshot {
2012 worktree_snapshots,
2013 unsaved_buffer_paths: unsaved_buffers,
2014 timestamp: Utc::now(),
2015 })
2016 })
2017 }
2018
    /// Builds a snapshot of one worktree: its absolute path plus, when a git
    /// repository contains it, the repo's remote URL, HEAD sha, current
    /// branch, and a HEAD-to-worktree diff. All failures degrade to `None`
    /// fields rather than errors.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app is gone, return an empty placeholder snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose root contains this worktree, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name.to_string());
                        // Queue a job on the repository; only local repos can
                        // answer remote/sha/diff queries.
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the Option<Result<Future<Result<_>>>> layering, mapping
            // every failure to "no git state".
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2096
2097 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2098 let mut markdown = Vec::new();
2099
2100 if let Some(summary) = self.summary() {
2101 writeln!(markdown, "# {summary}\n")?;
2102 };
2103
2104 for message in self.messages() {
2105 writeln!(
2106 markdown,
2107 "## {role}\n",
2108 role = match message.role {
2109 Role::User => "User",
2110 Role::Assistant => "Assistant",
2111 Role::System => "System",
2112 }
2113 )?;
2114
2115 if !message.loaded_context.text.is_empty() {
2116 writeln!(markdown, "{}", message.loaded_context.text)?;
2117 }
2118
2119 if !message.loaded_context.images.is_empty() {
2120 writeln!(
2121 markdown,
2122 "\n{} images attached as context.\n",
2123 message.loaded_context.images.len()
2124 )?;
2125 }
2126
2127 for segment in &message.segments {
2128 match segment {
2129 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2130 MessageSegment::Thinking { text, .. } => {
2131 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2132 }
2133 MessageSegment::RedactedThinking(_) => {}
2134 }
2135 }
2136
2137 for tool_use in self.tool_uses_for_message(message.id, cx) {
2138 writeln!(
2139 markdown,
2140 "**Use Tool: {} ({})**",
2141 tool_use.name, tool_use.id
2142 )?;
2143 writeln!(markdown, "```json")?;
2144 writeln!(
2145 markdown,
2146 "{}",
2147 serde_json::to_string_pretty(&tool_use.input)?
2148 )?;
2149 writeln!(markdown, "```")?;
2150 }
2151
2152 for tool_result in self.tool_results_for_message(message.id) {
2153 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2154 if tool_result.is_error {
2155 write!(markdown, " (Error)")?;
2156 }
2157
2158 writeln!(markdown, "**\n")?;
2159 writeln!(markdown, "{}", tool_result.content)?;
2160 }
2161 }
2162
2163 Ok(String::from_utf8_lossy(&markdown).to_string())
2164 }
2165
2166 pub fn keep_edits_in_range(
2167 &mut self,
2168 buffer: Entity<language::Buffer>,
2169 buffer_range: Range<language::Anchor>,
2170 cx: &mut Context<Self>,
2171 ) {
2172 self.action_log.update(cx, |action_log, cx| {
2173 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2174 });
2175 }
2176
2177 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2178 self.action_log
2179 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2180 }
2181
2182 pub fn reject_edits_in_ranges(
2183 &mut self,
2184 buffer: Entity<language::Buffer>,
2185 buffer_ranges: Vec<Range<language::Anchor>>,
2186 cx: &mut Context<Self>,
2187 ) -> Task<Result<()>> {
2188 self.action_log.update(cx, |action_log, cx| {
2189 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2190 })
2191 }
2192
    /// The action log tracking edits the agent has made to project buffers.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2196
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2200
    /// Serializes the thread and sends it via telemetry, for users opted into
    /// the auto-capture feature flag. Throttled to at most once per 10 seconds.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
            return;
        }

        // Rate-limit: skip if we captured less than 10 seconds ago.
        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        // Gather everything that needs `cx` before moving to the background.
        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client().clone();
        let serialize_task = self.serialize(cx);

        cx.background_executor()
            .spawn(async move {
                // Serialization failures are silently dropped; this is best-effort.
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events().await;
                    }
                }
            })
            .detach();
    }
2244
    /// Total token usage accumulated across all completions in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2248
2249 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
2250 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
2251 return TotalTokenUsage::default();
2252 };
2253
2254 let max = model.model.max_token_count();
2255
2256 let index = self
2257 .messages
2258 .iter()
2259 .position(|msg| msg.id == message_id)
2260 .unwrap_or(0);
2261
2262 if index == 0 {
2263 return TotalTokenUsage { total: 0, max };
2264 }
2265
2266 let token_usage = &self
2267 .request_token_usage
2268 .get(index - 1)
2269 .cloned()
2270 .unwrap_or_default();
2271
2272 TotalTokenUsage {
2273 total: token_usage.total_tokens() as usize,
2274 max,
2275 }
2276 }
2277
    /// Current total token usage for the thread against the default model's
    /// context window. If this model previously overflowed its window, the
    /// recorded overflow count is reported instead of the last usage.
    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
        let model_registry = LanguageModelRegistry::read_global(cx);
        let Some(model) = model_registry.default_model() else {
            return TotalTokenUsage::default();
        };

        let max = model.model.max_token_count();

        // Only honor the overflow record if it belongs to the active model;
        // switching models invalidates it.
        if let Some(exceeded_error) = &self.exceeded_window_error {
            if model.model.id() == exceeded_error.model_id {
                return TotalTokenUsage {
                    total: exceeded_error.token_count,
                    max,
                };
            }
        }

        let total = self
            .token_usage_at_last_message()
            .unwrap_or_default()
            .total_tokens() as usize;

        TotalTokenUsage { total, max }
    }
2302
2303 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2304 self.request_token_usage
2305 .get(self.messages.len().saturating_sub(1))
2306 .or_else(|| self.request_token_usage.last())
2307 .cloned()
2308 }
2309
2310 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2311 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2312 self.request_token_usage
2313 .resize(self.messages.len(), placeholder);
2314
2315 if let Some(last) = self.request_token_usage.last_mut() {
2316 *last = token_usage;
2317 }
2318 }
2319
2320 pub fn deny_tool_use(
2321 &mut self,
2322 tool_use_id: LanguageModelToolUseId,
2323 tool_name: Arc<str>,
2324 window: Option<AnyWindowHandle>,
2325 cx: &mut Context<Self>,
2326 ) {
2327 let err = Err(anyhow::anyhow!(
2328 "Permission to run tool action denied by user"
2329 ));
2330
2331 self.tool_use
2332 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2333 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2334 }
2335}
2336
/// User-facing error states surfaced while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before further requests are served.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The model request limit for the user's plan has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a short header and a detail message for
    /// display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2351
/// Events emitted by a `Thread` as completions stream in and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be presented to the user.
    ShowError(ThreadError),
    /// New request-usage figures were received.
    UsageUpdated(RequestUsage),
    /// A completion event was streamed from the model.
    StreamedCompletion,
    /// A chunk of text arrived from the model.
    ReceivedTextChunk,
    /// Streamed assistant text for the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" content for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with parsed input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that failed to parse as JSON; the raw
    /// text is kept for display/debugging.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message's content was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool invocation completed.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint state changed.
    CheckpointChanged,
    /// A tool is waiting for user confirmation before it may run.
    ToolConfirmationNeeded,
}
2388
2389impl EventEmitter<ThreadEvent> for Thread {}
2390
/// An in-flight model completion owned by a `Thread`.
struct PendingCompletion {
    // Identifier used to tell this completion apart from newer ones.
    id: usize,
    // Keeps the spawned completion task alive while this struct exists.
    _task: Task<()>,
}
2395
// Unit tests covering context attachment, context deduplication across
// messages, and stale-buffer notifications in completion requests.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use assistant_tool::ToolRegistry;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use language_model::fake_provider::FakeLanguageModel;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // Attached file context should be wrapped in a <context> block and
    // prepended to the user's message text in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", loaded_context, None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    // Context already attached to an earlier message should not be repeated
    // in later messages: each message carries only contexts new to the thread.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    // Messages inserted with an empty context should carry no <context> block
    // and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    // Editing a buffer that was previously attached as context should cause a
    // trailing "stale files" notification message in the next request.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at byte offset 1 (empty range 1..1), dirtying the buffer
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    // Registers the global settings, prompt/thread stores, and tool registry
    // that every test in this module depends on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    // Builds the full fixture for a test: a workspace window over `project`,
    // a thread store with a fresh thread, an empty context store, and a fake
    // language model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }

    // Opens `path` as a buffer in `project` and registers it with the context
    // store, returning the buffer so tests can edit it afterwards.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
}