1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use util::{ResultExt as _, TryFutureExt as _, post_inc};
39use uuid::Uuid;
40use zed_llm_client::CompletionMode;
41
42use crate::ThreadStore;
43use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
44use crate::thread_store::{
45 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
46 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
47};
48use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
49
50#[derive(
51 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
52)]
53pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
73/// The ID of the user prompt that initiated a request.
74///
75/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
76#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
77pub struct PromptId(Arc<str>);
78
79impl PromptId {
80 pub fn new() -> Self {
81 Self(Uuid::new_v4().to_string().into())
82 }
83}
84
85impl std::fmt::Display for PromptId {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 write!(f, "{}", self.0)
88 }
89}
90
91#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
92pub struct MessageId(pub(crate) usize);
93
94impl MessageId {
95 fn post_inc(&mut self) -> Self {
96 Self(post_inc(&mut self.0))
97 }
98}
99
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Span covered by the crease. It is persisted as `start`/`end` when the thread is
    /// serialized. NOTE(review): presumably offsets into the message text — confirm.
    pub range: Range<usize>,
    /// Icon and label used to render the crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier, allocated in increasing order by the owning thread.
    pub id: MessageId,
    /// Who authored the message (user, assistant or system).
    pub role: Role,
    /// Content of the message, in order.
    pub segments: Vec<MessageSegment>,
    /// Context attached to the message. Only its `text` survives serialization; the
    /// `contexts` and `images` are empty after deserialization.
    pub loaded_context: LoadedContext,
    /// Creases to recreate when the message is shown in an editor.
    pub creases: Vec<MessageCrease>,
}
118
119impl Message {
120 /// Returns whether the message contains any meaningful text that should be displayed
121 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
122 pub fn should_display_content(&self) -> bool {
123 self.segments.iter().all(|segment| segment.should_display())
124 }
125
126 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
127 if let Some(MessageSegment::Thinking {
128 text: segment,
129 signature: current_signature,
130 }) = self.segments.last_mut()
131 {
132 if let Some(signature) = signature {
133 *current_signature = Some(signature);
134 }
135 segment.push_str(text);
136 } else {
137 self.segments.push(MessageSegment::Thinking {
138 text: text.to_string(),
139 signature,
140 });
141 }
142 }
143
144 pub fn push_text(&mut self, text: &str) {
145 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
146 segment.push_str(text);
147 } else {
148 self.segments.push(MessageSegment::Text(text.to_string()));
149 }
150 }
151
152 pub fn to_string(&self) -> String {
153 let mut result = String::new();
154
155 if !self.loaded_context.text.is_empty() {
156 result.push_str(&self.loaded_context.text);
157 }
158
159 for segment in &self.segments {
160 match segment {
161 MessageSegment::Text(text) => result.push_str(text),
162 MessageSegment::Thinking { text, .. } => {
163 result.push_str("<think>\n");
164 result.push_str(text);
165 result.push_str("\n</think>");
166 }
167 MessageSegment::RedactedThinking(_) => {}
168 }
169 }
170
171 result
172 }
173}
174
/// One piece of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain text.
    Text(String),
    /// Thinking text, with the signature that accompanied it (if any).
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Redacted thinking, kept only as raw bytes; it is never rendered as text.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns `true` when this segment carries no text, i.e. an empty `Text` or an
    /// empty `Thinking` segment. `RedactedThinking` always yields `false`.
    ///
    /// NOTE(review): the name suggests the opposite ("should be displayed"), but the
    /// behavior is an emptiness check. `Message::should_display_content` aggregates it
    /// as written, so confirm callers before renaming or inverting it.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) | Self::Thinking { text, .. } => text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
194
/// Snapshot of the project's state.
///
/// A thread captures one when it is created (`Thread::new`) and persists it as its
/// initial project snapshot when serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// Per-worktree state. NOTE(review): presumably one entry per worktree — confirm.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers with unsaved changes at capture time (presumably).
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken (presumably).
    pub timestamp: DateTime<Utc>,
}

/// State of a single worktree within a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree, as a string.
    pub worktree_path: String,
    /// Git information, when available.
    pub git_state: Option<GitState>,
}

/// Git information captured for a worktree. Each field is optional, presumably
/// `None` when it could not be determined.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the repository's remote.
    pub remote_url: Option<String>,
    /// SHA of the current HEAD commit.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch.
    pub current_branch: Option<String>,
    /// Textual diff of the working tree. NOTE(review): which diff (staged/unstaged/HEAD) is not visible here.
    pub diff: Option<String>,
}
215
/// A git checkpoint tied to a message.
///
/// Restoring it via `Thread::restore_checkpoint` restores the git state and truncates
/// the thread starting at `message_id`.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    /// Message the checkpoint belongs to; on restore, the thread is truncated from here.
    message_id: MessageId,
    /// Git store checkpoint to restore.
    git_checkpoint: GitStoreCheckpoint,
}

/// A rating given either to a whole thread or to an individual message (see the
/// `feedback` and `message_feedback` fields of `Thread`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
227
228pub enum LastRestoreCheckpoint {
229 Pending {
230 message_id: MessageId,
231 },
232 Error {
233 message_id: MessageId,
234 error: String,
235 },
236}
237
238impl LastRestoreCheckpoint {
239 pub fn message_id(&self) -> MessageId {
240 match self {
241 LastRestoreCheckpoint::Pending { message_id } => *message_id,
242 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
243 }
244 }
245}
246
247#[derive(Clone, Debug, Default, Serialize, Deserialize)]
248pub enum DetailedSummaryState {
249 #[default]
250 NotGenerated,
251 Generating {
252 message_id: MessageId,
253 },
254 Generated {
255 text: SharedString,
256 message_id: MessageId,
257 },
258}
259
260impl DetailedSummaryState {
261 fn text(&self) -> Option<SharedString> {
262 if let Self::Generated { text, .. } = self {
263 Some(text.clone())
264 } else {
265 None
266 }
267 }
268}
269
/// Token usage of a thread measured against a model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens used so far.
    pub total: usize,
    /// Maximum number of tokens; `0` when unknown (e.g. no model is selected).
    pub max: usize,
}

impl TotalTokenUsage {
    /// Fraction of `max` at which usage is reported as [`TokenUsageRatio::Warning`].
    const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

    /// Classifies `total` relative to `max`.
    ///
    /// The result is:
    /// - [`TokenUsageRatio::Normal`] when `max` is `0`;
    /// - [`TokenUsageRatio::Exceeded`] once `total >= max`;
    /// - [`TokenUsageRatio::Warning`] once `total / max` reaches the warning threshold;
    /// - `Normal` otherwise.
    ///
    /// In debug builds the threshold can be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable. A value that does not parse
    /// as an `f32` falls back to the default instead of panicking.
    pub fn ratio(&self) -> TokenUsageRatio {
        let warning_threshold = Self::warning_threshold();

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Warning threshold, honoring the `ZED_THREAD_WARNING_THRESHOLD` override when it
    /// parses; otherwise [`Self::DEFAULT_WARNING_THRESHOLD`].
    #[cfg(debug_assertions)]
    fn warning_threshold() -> f32 {
        std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.trim().parse::<f32>().ok())
            .unwrap_or(Self::DEFAULT_WARNING_THRESHOLD)
    }

    /// Warning threshold; release builds ignore the environment override.
    #[cfg(not(debug_assertions))]
    fn warning_threshold() -> f32 {
        Self::DEFAULT_WARNING_THRESHOLD
    }

    /// Returns a copy with `tokens` added to `total`; `max` is unchanged.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close a thread's token usage is to the model's limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the limit is unknown.
    #[default]
    Normal,
    /// At or above the warning threshold but below the limit.
    Warning,
    /// At or above the limit.
    Exceeded,
}
314
315fn default_completion_mode(cx: &App) -> CompletionMode {
316 if cx.is_staff() {
317 CompletionMode::Max
318 } else {
319 CompletionMode::Normal
320 }
321}
322
/// Where a pending completion request stands; `Thread::queue_state` reports it for the
/// first pending completion.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// NOTE(review): presumably the request is being sent and not yet acknowledged.
    Sending,
    /// The request is queued; `position` is presumably its place in the queue.
    Queued { position: usize },
    /// NOTE(review): presumably the request has started being served.
    Started,
}
329
/// A thread of conversation with the LLM.
///
/// Holds the messages, tool-use state, git checkpoints, token accounting and the model
/// configuration for one conversation.
pub struct Thread {
    /// Unique id of the thread.
    id: ThreadId,
    /// Last modification time, bumped by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// Generated summary. While this is `None`, `set_summary` refuses to set one.
    summary: Option<SharedString>,
    /// NOTE(review): presumably the in-flight summary-generation task; unused in this section.
    pending_summary: Task<Option<()>>,
    /// NOTE(review): presumably the in-flight detailed-summary task; unused in this section.
    detailed_summary_task: Task<Option<()>>,
    /// Sender half of the watch channel carrying the detailed summary state.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    /// Receiver half; `serialize` reads the current state from it.
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    /// Completion mode used for requests.
    completion_mode: CompletionMode,
    /// Messages in ascending id order; `message()` relies on this for its binary search.
    messages: Vec<Message>,
    /// Next id handed out by `insert_message`.
    next_message_id: MessageId,
    /// Sent as the request's `prompt_id`; replaced by `advance_prompt_id`.
    last_prompt_id: PromptId,
    /// Project context used to build the system prompt.
    project_context: SharedProjectContext,
    /// Git checkpoints recorded per (user) message.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    /// NOTE(review): presumably counts completions started; unused in this section.
    completion_count: usize,
    /// In-flight completions. The thread is generating while this is non-empty.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    /// Pending and completed tool uses and their results.
    tool_use: ToolUseState,
    /// Tracks buffers referenced by the thread (see `insert_user_message`).
    action_log: Entity<ActionLog>,
    /// Status of the last checkpoint restore (pending or failed); `None` after success.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken with the latest user message, finalized once generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot taken when the thread was created (or loaded from storage).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage (persisted with the thread).
    request_token_usage: Vec<TokenUsage>,
    /// Token usage accumulated across requests (persisted with the thread).
    cumulative_token_usage: TokenUsage,
    /// Set when a request exceeded the model's context window; not restored on deserialize.
    exceeded_window_error: Option<ExceededWindowError>,
    /// Feedback on the whole thread.
    feedback: Option<ThreadFeedback>,
    /// Feedback on individual messages.
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    /// NOTE(review): presumably used to throttle telemetry auto-capture.
    last_auto_capture_at: Option<Instant>,
    /// When the last streamed chunk arrived; used by `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    /// Optional observer installed via `set_request_callback`.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Requests `send_to_model` may still make; at 0 it does nothing.
    remaining_turns: u32,
    /// Model the thread sends requests to, if any.
    configured_model: Option<ConfiguredModel>,
}

/// Details of the last message that exceeded a model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
376
377impl Thread {
 /// Creates an empty thread for `project`.
 ///
 /// The thread starts with:
 /// - a fresh [`ThreadId`] and the registry's default model (if any);
 /// - the default completion mode;
 /// - `remaining_turns` set to `u32::MAX`.
 ///
 /// `system_prompt` is stored as the thread's project context. A project snapshot is
 /// captured asynchronously and kept as the thread's initial project snapshot.
 pub fn new(
     project: Entity<Project>,
     tools: Entity<ToolWorkingSet>,
     prompt_builder: Arc<PromptBuilder>,
     system_prompt: SharedProjectContext,
     cx: &mut Context<Self>,
 ) -> Self {
     let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
     let configured_model = LanguageModelRegistry::read_global(cx).default_model();

     Self {
         id: ThreadId::new(),
         updated_at: Utc::now(),
         summary: None,
         pending_summary: Task::ready(None),
         detailed_summary_task: Task::ready(None),
         detailed_summary_tx,
         detailed_summary_rx,
         completion_mode: default_completion_mode(cx),
         messages: Vec::new(),
         next_message_id: MessageId(0),
         last_prompt_id: PromptId::new(),
         project_context: system_prompt,
         checkpoints_by_message: HashMap::default(),
         completion_count: 0,
         pending_completions: Vec::new(),
         project: project.clone(),
         prompt_builder,
         tools: tools.clone(),
         last_restore_checkpoint: None,
         pending_checkpoint: None,
         tool_use: ToolUseState::new(tools.clone()),
         action_log: cx.new(|_| ActionLog::new(project.clone())),
         // `project` is moved here, after the clones taken by the fields above.
         initial_project_snapshot: {
             let project_snapshot = Self::project_snapshot(project, cx);
             cx.foreground_executor()
                 .spawn(async move { Some(project_snapshot.await) })
                 .shared()
         },
         request_token_usage: Vec::new(),
         cumulative_token_usage: TokenUsage::default(),
         exceeded_window_error: None,
         feedback: None,
         message_feedback: HashMap::default(),
         last_auto_capture_at: None,
         last_received_chunk_at: None,
         request_callback: None,
         remaining_turns: u32::MAX,
         configured_model,
     }
 }
429
 /// Rebuilds a thread from its serialized form.
 ///
 /// What is restored:
 /// - Messages and their segments.
 /// - Context text only: `contexts` and `images` start empty, and crease contexts are `None`.
 /// - Tool-use state, token usage, the summary and the detailed summary state.
 /// - The model, by selecting the stored provider/model in the registry and falling back
 ///   to the registry's default model.
 ///
 /// `next_message_id` is one past the id of the *last* serialized message (0 when
 /// there are none), which assumes messages are stored in ascending id order.
 ///
 /// What is not restored: the completion mode (reset to the default), checkpoints,
 /// feedback, and `exceeded_window_error`.
 /// NOTE(review): `serialize` writes `exceeded_window_error`; confirm that dropping it
 /// here is intentional.
 pub fn deserialize(
     id: ThreadId,
     serialized: SerializedThread,
     project: Entity<Project>,
     tools: Entity<ToolWorkingSet>,
     prompt_builder: Arc<PromptBuilder>,
     project_context: SharedProjectContext,
     cx: &mut Context<Self>,
 ) -> Self {
     let next_message_id = MessageId(
         serialized
             .messages
             .last()
             .map(|message| message.id.0 + 1)
             .unwrap_or(0),
     );
     let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
     let (detailed_summary_tx, detailed_summary_rx) =
         postage::watch::channel_with(serialized.detailed_summary_state);

     // Prefer the model saved with the thread; fall back to the registry default.
     let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
         serialized
             .model
             .and_then(|model| {
                 let model = SelectedModel {
                     provider: model.provider.clone().into(),
                     model: model.model.clone().into(),
                 };
                 registry.select_model(&model, cx)
             })
             .or_else(|| registry.default_model())
     });

     Self {
         id,
         updated_at: serialized.updated_at,
         summary: Some(serialized.summary),
         pending_summary: Task::ready(None),
         detailed_summary_task: Task::ready(None),
         detailed_summary_tx,
         detailed_summary_rx,
         completion_mode: default_completion_mode(cx),
         messages: serialized
             .messages
             .into_iter()
             .map(|message| Message {
                 id: message.id,
                 role: message.role,
                 segments: message
                     .segments
                     .into_iter()
                     .map(|segment| match segment {
                         SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                         SerializedMessageSegment::Thinking { text, signature } => {
                             MessageSegment::Thinking { text, signature }
                         }
                         SerializedMessageSegment::RedactedThinking { data } => {
                             MessageSegment::RedactedThinking(data)
                         }
                     })
                     .collect(),
                 // Only the context text is persisted; the live contexts are gone.
                 loaded_context: LoadedContext {
                     contexts: Vec::new(),
                     text: message.context,
                     images: Vec::new(),
                 },
                 creases: message
                     .creases
                     .into_iter()
                     .map(|crease| MessageCrease {
                         range: crease.start..crease.end,
                         metadata: CreaseMetadata {
                             icon_path: crease.icon_path,
                             label: crease.label,
                         },
                         context: None,
                     })
                     .collect(),
             })
             .collect(),
         next_message_id,
         last_prompt_id: PromptId::new(),
         project_context,
         checkpoints_by_message: HashMap::default(),
         completion_count: 0,
         pending_completions: Vec::new(),
         last_restore_checkpoint: None,
         pending_checkpoint: None,
         project: project.clone(),
         prompt_builder,
         tools,
         tool_use,
         action_log: cx.new(|_| ActionLog::new(project)),
         initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
         request_token_usage: serialized.request_token_usage,
         cumulative_token_usage: serialized.cumulative_token_usage,
         exceeded_window_error: None,
         feedback: None,
         message_feedback: HashMap::default(),
         last_auto_capture_at: None,
         last_received_chunk_at: None,
         request_callback: None,
         remaining_turns: u32::MAX,
         configured_model,
     }
 }
536
537 pub fn set_request_callback(
538 &mut self,
539 callback: impl 'static
540 + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
541 ) {
542 self.request_callback = Some(Box::new(callback));
543 }
544
545 pub fn id(&self) -> &ThreadId {
546 &self.id
547 }
548
549 pub fn is_empty(&self) -> bool {
550 self.messages.is_empty()
551 }
552
553 pub fn updated_at(&self) -> DateTime<Utc> {
554 self.updated_at
555 }
556
557 pub fn touch_updated_at(&mut self) {
558 self.updated_at = Utc::now();
559 }
560
561 pub fn advance_prompt_id(&mut self) {
562 self.last_prompt_id = PromptId::new();
563 }
564
565 pub fn summary(&self) -> Option<SharedString> {
566 self.summary.clone()
567 }
568
569 pub fn project_context(&self) -> SharedProjectContext {
570 self.project_context.clone()
571 }
572
573 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
574 if self.configured_model.is_none() {
575 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
576 }
577 self.configured_model.clone()
578 }
579
580 pub fn configured_model(&self) -> Option<ConfiguredModel> {
581 self.configured_model.clone()
582 }
583
584 pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
585 self.configured_model = model;
586 cx.notify();
587 }
588
589 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
590
591 pub fn summary_or_default(&self) -> SharedString {
592 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
593 }
594
595 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
596 let Some(current_summary) = &self.summary else {
597 // Don't allow setting summary until generated
598 return;
599 };
600
601 let mut new_summary = new_summary.into();
602
603 if new_summary.is_empty() {
604 new_summary = Self::DEFAULT_SUMMARY;
605 }
606
607 if current_summary != &new_summary {
608 self.summary = Some(new_summary);
609 cx.emit(ThreadEvent::SummaryChanged);
610 }
611 }
612
613 pub fn completion_mode(&self) -> CompletionMode {
614 self.completion_mode
615 }
616
617 pub fn set_completion_mode(&mut self, mode: CompletionMode) {
618 self.completion_mode = mode;
619 }
620
621 pub fn message(&self, id: MessageId) -> Option<&Message> {
622 let index = self
623 .messages
624 .binary_search_by(|message| message.id.cmp(&id))
625 .ok()?;
626
627 self.messages.get(index)
628 }
629
630 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
631 self.messages.iter()
632 }
633
634 pub fn is_generating(&self) -> bool {
635 !self.pending_completions.is_empty() || !self.all_tools_finished()
636 }
637
638 /// Indicates whether streaming of language model events is stale.
639 /// When `is_generating()` is false, this method returns `None`.
640 pub fn is_generation_stale(&self) -> Option<bool> {
641 const STALE_THRESHOLD: u128 = 250;
642
643 self.last_received_chunk_at
644 .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
645 }
646
647 fn received_chunk(&mut self) {
648 self.last_received_chunk_at = Some(Instant::now());
649 }
650
651 pub fn queue_state(&self) -> Option<QueueState> {
652 self.pending_completions
653 .first()
654 .map(|pending_completion| pending_completion.queue_state)
655 }
656
 /// The working set of tools available to this thread.
 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
     &self.tools
 }

 /// Finds a still-pending tool use by its id.
 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
     self.tool_use
         .pending_tool_uses()
         .into_iter()
         .find(|tool_use| &tool_use.id == id)
 }

 /// Pending tool uses whose status says they are waiting for user confirmation.
 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
     self.tool_use
         .pending_tool_uses()
         .into_iter()
         .filter(|tool_use| tool_use.status.needs_confirmation())
 }

 /// Whether any tool use is still pending. Tool uses whose status is an error still
 /// count here; see `all_tools_finished` for the check that ignores them.
 pub fn has_pending_tool_uses(&self) -> bool {
     !self.tool_use.pending_tool_uses().is_empty()
 }

 /// Returns a clone of the checkpoint recorded for message `id`, if any.
 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
     self.checkpoints_by_message.get(&id).cloned()
 }
682
 /// Restores the project's git state to `checkpoint` and, on success, truncates the
 /// thread starting at the checkpoint's message.
 ///
 /// While the restore runs, `last_restore_checkpoint` is `Pending`. Afterwards:
 /// - On failure it becomes `Error` with the error text, and the messages are left as-is.
 /// - On success it is cleared.
 ///
 /// Either way the pending checkpoint is discarded, `ThreadEvent::CheckpointChanged` is
 /// emitted, and the returned task resolves to the restore result. If updating the
 /// thread fails (presumably because the entity was released), that error is returned
 /// instead.
 pub fn restore_checkpoint(
     &mut self,
     checkpoint: ThreadCheckpoint,
     cx: &mut Context<Self>,
 ) -> Task<Result<()>> {
     self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
         message_id: checkpoint.message_id,
     });
     cx.emit(ThreadEvent::CheckpointChanged);
     cx.notify();

     let git_store = self.project().read(cx).git_store().clone();
     let restore = git_store.update(cx, |git_store, cx| {
         git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
     });

     cx.spawn(async move |this, cx| {
         let result = restore.await;
         this.update(cx, |this, cx| {
             if let Err(err) = result.as_ref() {
                 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                     message_id: checkpoint.message_id,
                     error: err.to_string(),
                 });
             } else {
                 // Drop the checkpoint's message and everything after it.
                 this.truncate(checkpoint.message_id, cx);
                 this.last_restore_checkpoint = None;
             }
             this.pending_checkpoint = None;
             cx.emit(ThreadEvent::CheckpointChanged);
             cx.notify();
         })?;
         result
     })
 }
718
719 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
720 let pending_checkpoint = if self.is_generating() {
721 return;
722 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
723 checkpoint
724 } else {
725 return;
726 };
727
728 let git_store = self.project.read(cx).git_store().clone();
729 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
730 cx.spawn(async move |this, cx| match final_checkpoint.await {
731 Ok(final_checkpoint) => {
732 let equal = git_store
733 .update(cx, |store, cx| {
734 store.compare_checkpoints(
735 pending_checkpoint.git_checkpoint.clone(),
736 final_checkpoint.clone(),
737 cx,
738 )
739 })?
740 .await
741 .unwrap_or(false);
742
743 if !equal {
744 this.update(cx, |this, cx| {
745 this.insert_checkpoint(pending_checkpoint, cx)
746 })?;
747 }
748
749 Ok(())
750 }
751 Err(_) => this.update(cx, |this, cx| {
752 this.insert_checkpoint(pending_checkpoint, cx)
753 }),
754 })
755 .detach();
756 }
757
758 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
759 self.checkpoints_by_message
760 .insert(checkpoint.message_id, checkpoint);
761 cx.emit(ThreadEvent::CheckpointChanged);
762 cx.notify();
763 }
764
765 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
766 self.last_restore_checkpoint.as_ref()
767 }
768
769 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
770 let Some(message_ix) = self
771 .messages
772 .iter()
773 .rposition(|message| message.id == message_id)
774 else {
775 return;
776 };
777 for deleted_message in self.messages.drain(message_ix..) {
778 self.checkpoints_by_message.remove(&deleted_message.id);
779 }
780 cx.notify();
781 }
782
783 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
784 self.messages
785 .iter()
786 .find(|message| message.id == id)
787 .into_iter()
788 .flat_map(|message| message.loaded_context.contexts.iter())
789 }
790
791 pub fn is_turn_end(&self, ix: usize) -> bool {
792 if self.messages.is_empty() {
793 return false;
794 }
795
796 if !self.is_generating() && ix == self.messages.len() - 1 {
797 return true;
798 }
799
800 let Some(message) = self.messages.get(ix) else {
801 return false;
802 };
803
804 if message.role != Role::Assistant {
805 return false;
806 }
807
808 self.messages
809 .get(ix + 1)
810 .and_then(|message| {
811 self.message(message.id)
812 .map(|next_message| next_message.role == Role::User)
813 })
814 .unwrap_or(false)
815 }
816
817 /// Returns whether all of the tool uses have finished running.
818 pub fn all_tools_finished(&self) -> bool {
819 // If the only pending tool uses left are the ones with errors, then
820 // that means that we've finished running all of the pending tools.
821 self.tool_use
822 .pending_tool_uses()
823 .iter()
824 .all(|tool_use| tool_use.status.is_error())
825 }
826
827 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
828 self.tool_use.tool_uses_for_message(id, cx)
829 }
830
831 pub fn tool_results_for_message(
832 &self,
833 assistant_message_id: MessageId,
834 ) -> Vec<&LanguageModelToolResult> {
835 self.tool_use.tool_results_for_message(assistant_message_id)
836 }
837
838 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
839 self.tool_use.tool_result(id)
840 }
841
842 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
843 Some(&self.tool_use.tool_result(id)?.content)
844 }
845
846 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
847 self.tool_use.tool_result_card(id).cloned()
848 }
849
850 /// Return tools that are both enabled and supported by the model
851 pub fn available_tools(
852 &self,
853 cx: &App,
854 model: Arc<dyn LanguageModel>,
855 ) -> Vec<LanguageModelRequestTool> {
856 if model.supports_tools() {
857 self.tools()
858 .read(cx)
859 .enabled_tools(cx)
860 .into_iter()
861 .filter_map(|tool| {
862 // Skip tools that cannot be supported
863 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
864 Some(LanguageModelRequestTool {
865 name: tool.name(),
866 description: tool.description(),
867 input_schema,
868 })
869 })
870 .collect()
871 } else {
872 Vec::default()
873 }
874 }
875
 /// Appends a user message with `text` and the loaded context, and returns its id.
 ///
 /// Buffers referenced by `loaded_context` are registered with the action log
 /// (`track_buffer`) before the message is inserted. When `git_checkpoint` is `Some`,
 /// it becomes the pending checkpoint for the new message, to be finalized later by
 /// `finalize_pending_checkpoint`. When it is `None`, any existing pending checkpoint is
 /// left untouched.
 pub fn insert_user_message(
     &mut self,
     text: impl Into<String>,
     loaded_context: ContextLoadResult,
     git_checkpoint: Option<GitStoreCheckpoint>,
     creases: Vec<MessageCrease>,
     cx: &mut Context<Self>,
 ) -> MessageId {
     if !loaded_context.referenced_buffers.is_empty() {
         self.action_log.update(cx, |log, cx| {
             for buffer in loaded_context.referenced_buffers {
                 log.track_buffer(buffer, cx);
             }
         });
     }

     let message_id = self.insert_message(
         Role::User,
         vec![MessageSegment::Text(text.into())],
         loaded_context.loaded_context,
         creases,
         cx,
     );

     if let Some(git_checkpoint) = git_checkpoint {
         self.pending_checkpoint = Some(ThreadCheckpoint {
             message_id,
             git_checkpoint,
         });
     }

     // NOTE(review): presumably captures the thread for telemetry (possibly throttled).
     self.auto_capture_telemetry(cx);

     message_id
 }
911
912 pub fn insert_assistant_message(
913 &mut self,
914 segments: Vec<MessageSegment>,
915 cx: &mut Context<Self>,
916 ) -> MessageId {
917 self.insert_message(
918 Role::Assistant,
919 segments,
920 LoadedContext::default(),
921 Vec::new(),
922 cx,
923 )
924 }
925
926 pub fn insert_message(
927 &mut self,
928 role: Role,
929 segments: Vec<MessageSegment>,
930 loaded_context: LoadedContext,
931 creases: Vec<MessageCrease>,
932 cx: &mut Context<Self>,
933 ) -> MessageId {
934 let id = self.next_message_id.post_inc();
935 self.messages.push(Message {
936 id,
937 role,
938 segments,
939 loaded_context,
940 creases,
941 });
942 self.touch_updated_at();
943 cx.emit(ThreadEvent::MessageAdded(id));
944 id
945 }
946
947 pub fn edit_message(
948 &mut self,
949 id: MessageId,
950 new_role: Role,
951 new_segments: Vec<MessageSegment>,
952 loaded_context: Option<LoadedContext>,
953 cx: &mut Context<Self>,
954 ) -> bool {
955 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
956 return false;
957 };
958 message.role = new_role;
959 message.segments = new_segments;
960 if let Some(context) = loaded_context {
961 message.loaded_context = context;
962 }
963 self.touch_updated_at();
964 cx.emit(ThreadEvent::MessageEdited(id));
965 true
966 }
967
968 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
969 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
970 return false;
971 };
972 self.messages.remove(index);
973 self.touch_updated_at();
974 cx.emit(ThreadEvent::MessageDeleted(id));
975 true
976 }
977
978 /// Returns the representation of this [`Thread`] in a textual form.
979 ///
980 /// This is the representation we use when attaching a thread as context to another thread.
981 pub fn text(&self) -> String {
982 let mut text = String::new();
983
984 for message in &self.messages {
985 text.push_str(match message.role {
986 language_model::Role::User => "User:",
987 language_model::Role::Assistant => "Assistant:",
988 language_model::Role::System => "System:",
989 });
990 text.push('\n');
991
992 for segment in &message.segments {
993 match segment {
994 MessageSegment::Text(content) => text.push_str(content),
995 MessageSegment::Thinking { text: content, .. } => {
996 text.push_str(&format!("<think>{}</think>", content))
997 }
998 MessageSegment::RedactedThinking(_) => {}
999 }
1000 }
1001 text.push('\n');
1002 }
1003
1004 text
1005 }
1006
 /// Serializes this thread into a format for storage or telemetry.
 ///
 /// The task first waits for the initial project snapshot to resolve, then reads the
 /// thread and builds a [`SerializedThread`]. The output contains:
 /// - the summary, via `summary_or_default`, so an ungenerated summary is stored as
 ///   [`Self::DEFAULT_SUMMARY`];
 /// - every message with its segments, tool uses, tool results, context text and creases
 ///   (only the *text* of the loaded context is kept);
 /// - the token usage, the current detailed summary state and any context-window error;
 /// - the configured model, as provider/model id strings.
 ///
 /// The task fails if reading the thread fails (presumably when the entity has been
 /// released before it runs).
 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
     let initial_project_snapshot = self.initial_project_snapshot.clone();
     cx.spawn(async move |this, cx| {
         let initial_project_snapshot = initial_project_snapshot.await;
         this.read_with(cx, |this, cx| SerializedThread {
             version: SerializedThread::VERSION.to_string(),
             summary: this.summary_or_default(),
             updated_at: this.updated_at(),
             messages: this
                 .messages()
                 .map(|message| SerializedMessage {
                     id: message.id,
                     role: message.role,
                     segments: message
                         .segments
                         .iter()
                         .map(|segment| match segment {
                             MessageSegment::Text(text) => {
                                 SerializedMessageSegment::Text { text: text.clone() }
                             }
                             MessageSegment::Thinking { text, signature } => {
                                 SerializedMessageSegment::Thinking {
                                     text: text.clone(),
                                     signature: signature.clone(),
                                 }
                             }
                             MessageSegment::RedactedThinking(data) => {
                                 SerializedMessageSegment::RedactedThinking {
                                     data: data.clone(),
                                 }
                             }
                         })
                         .collect(),
                     tool_uses: this
                         .tool_uses_for_message(message.id, cx)
                         .into_iter()
                         .map(|tool_use| SerializedToolUse {
                             id: tool_use.id,
                             name: tool_use.name,
                             input: tool_use.input,
                         })
                         .collect(),
                     tool_results: this
                         .tool_results_for_message(message.id)
                         .into_iter()
                         .map(|tool_result| SerializedToolResult {
                             tool_use_id: tool_result.tool_use_id.clone(),
                             is_error: tool_result.is_error,
                             content: tool_result.content.clone(),
                         })
                         .collect(),
                     // Only the rendered context text is persisted.
                     context: message.loaded_context.text.clone(),
                     creases: message
                         .creases
                         .iter()
                         .map(|crease| SerializedCrease {
                             start: crease.range.start,
                             end: crease.range.end,
                             icon_path: crease.metadata.icon_path.clone(),
                             label: crease.metadata.label.clone(),
                         })
                         .collect(),
                 })
                 .collect(),
             initial_project_snapshot,
             cumulative_token_usage: this.cumulative_token_usage,
             request_token_usage: this.request_token_usage.clone(),
             detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
             exceeded_window_error: this.exceeded_window_error.clone(),
             model: this
                 .configured_model
                 .as_ref()
                 .map(|model| SerializedLanguageModel {
                     provider: model.provider.id().0.to_string(),
                     model: model.model.id().0.to_string(),
                 }),
         })
     })
 }
1087
1088 pub fn remaining_turns(&self) -> u32 {
1089 self.remaining_turns
1090 }
1091
1092 pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
1093 self.remaining_turns = remaining_turns;
1094 }
1095
1096 pub fn send_to_model(
1097 &mut self,
1098 model: Arc<dyn LanguageModel>,
1099 window: Option<AnyWindowHandle>,
1100 cx: &mut Context<Self>,
1101 ) {
1102 if self.remaining_turns == 0 {
1103 return;
1104 }
1105
1106 self.remaining_turns -= 1;
1107
1108 let request = self.to_completion_request(model.clone(), cx);
1109
1110 self.stream_completion(request, model, window, cx);
1111 }
1112
1113 pub fn used_tools_since_last_user_message(&self) -> bool {
1114 for message in self.messages.iter().rev() {
1115 if self.tool_use.message_has_tool_results(message.id) {
1116 return true;
1117 } else if message.role == Role::User {
1118 return false;
1119 }
1120 }
1121
1122 false
1123 }
1124
1125 pub fn to_completion_request(
1126 &self,
1127 model: Arc<dyn LanguageModel>,
1128 cx: &mut Context<Self>,
1129 ) -> LanguageModelRequest {
1130 let mut request = LanguageModelRequest {
1131 thread_id: Some(self.id.to_string()),
1132 prompt_id: Some(self.last_prompt_id.to_string()),
1133 mode: None,
1134 messages: vec![],
1135 tools: Vec::new(),
1136 stop: Vec::new(),
1137 temperature: None,
1138 };
1139
1140 let available_tools = self.available_tools(cx, model.clone());
1141 let available_tool_names = available_tools
1142 .iter()
1143 .map(|tool| tool.name.clone())
1144 .collect();
1145
1146 let model_context = &ModelContext {
1147 available_tools: available_tool_names,
1148 };
1149
1150 if let Some(project_context) = self.project_context.borrow().as_ref() {
1151 match self
1152 .prompt_builder
1153 .generate_assistant_system_prompt(project_context, model_context)
1154 {
1155 Err(err) => {
1156 let message = format!("{err:?}").into();
1157 log::error!("{message}");
1158 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1159 header: "Error generating system prompt".into(),
1160 message,
1161 }));
1162 }
1163 Ok(system_prompt) => {
1164 request.messages.push(LanguageModelRequestMessage {
1165 role: Role::System,
1166 content: vec![MessageContent::Text(system_prompt)],
1167 cache: true,
1168 });
1169 }
1170 }
1171 } else {
1172 let message = "Context for system prompt unexpectedly not ready.".into();
1173 log::error!("{message}");
1174 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1175 header: "Error generating system prompt".into(),
1176 message,
1177 }));
1178 }
1179
1180 for message in &self.messages {
1181 let mut request_message = LanguageModelRequestMessage {
1182 role: message.role,
1183 content: Vec::new(),
1184 cache: false,
1185 };
1186
1187 message
1188 .loaded_context
1189 .add_to_request_message(&mut request_message);
1190
1191 for segment in &message.segments {
1192 match segment {
1193 MessageSegment::Text(text) => {
1194 if !text.is_empty() {
1195 request_message
1196 .content
1197 .push(MessageContent::Text(text.into()));
1198 }
1199 }
1200 MessageSegment::Thinking { text, signature } => {
1201 if !text.is_empty() {
1202 request_message.content.push(MessageContent::Thinking {
1203 text: text.into(),
1204 signature: signature.clone(),
1205 });
1206 }
1207 }
1208 MessageSegment::RedactedThinking(data) => {
1209 request_message
1210 .content
1211 .push(MessageContent::RedactedThinking(data.clone()));
1212 }
1213 };
1214 }
1215
1216 self.tool_use
1217 .attach_tool_uses(message.id, &mut request_message);
1218
1219 request.messages.push(request_message);
1220
1221 if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
1222 request.messages.push(tool_results_message);
1223 }
1224 }
1225
1226 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1227 if let Some(last) = request.messages.last_mut() {
1228 last.cache = true;
1229 }
1230
1231 self.attached_tracked_files_state(&mut request.messages, cx);
1232
1233 request.tools = available_tools;
1234 request.mode = if model.supports_max_mode() {
1235 Some(self.completion_mode)
1236 } else {
1237 Some(CompletionMode::Normal)
1238 };
1239
1240 request
1241 }
1242
1243 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1244 let mut request = LanguageModelRequest {
1245 thread_id: None,
1246 prompt_id: None,
1247 mode: None,
1248 messages: vec![],
1249 tools: Vec::new(),
1250 stop: Vec::new(),
1251 temperature: None,
1252 };
1253
1254 for message in &self.messages {
1255 let mut request_message = LanguageModelRequestMessage {
1256 role: message.role,
1257 content: Vec::new(),
1258 cache: false,
1259 };
1260
1261 for segment in &message.segments {
1262 match segment {
1263 MessageSegment::Text(text) => request_message
1264 .content
1265 .push(MessageContent::Text(text.clone())),
1266 MessageSegment::Thinking { .. } => {}
1267 MessageSegment::RedactedThinking(_) => {}
1268 }
1269 }
1270
1271 if request_message.content.is_empty() {
1272 continue;
1273 }
1274
1275 request.messages.push(request_message);
1276 }
1277
1278 request.messages.push(LanguageModelRequestMessage {
1279 role: Role::User,
1280 content: vec![MessageContent::Text(added_user_message)],
1281 cache: false,
1282 });
1283
1284 request
1285 }
1286
1287 fn attached_tracked_files_state(
1288 &self,
1289 messages: &mut Vec<LanguageModelRequestMessage>,
1290 cx: &App,
1291 ) {
1292 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1293
1294 let mut stale_message = String::new();
1295
1296 let action_log = self.action_log.read(cx);
1297
1298 for stale_file in action_log.stale_buffers(cx) {
1299 let Some(file) = stale_file.read(cx).file() else {
1300 continue;
1301 };
1302
1303 if stale_message.is_empty() {
1304 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1305 }
1306
1307 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1308 }
1309
1310 let mut content = Vec::with_capacity(2);
1311
1312 if !stale_message.is_empty() {
1313 content.push(stale_message.into());
1314 }
1315
1316 if !content.is_empty() {
1317 let context_message = LanguageModelRequestMessage {
1318 role: Role::User,
1319 content,
1320 cache: false,
1321 };
1322
1323 messages.push(context_message);
1324 }
1325 }
1326
    /// Streams a completion for `request` from `model` and applies the streamed
    /// events to this thread.
    ///
    /// A background task is spawned and tracked in `pending_completions` (initial
    /// queue state `QueueState::Sending`) under a freshly allocated completion id.
    /// While streaming it:
    /// - emits `UsageUpdated` (when usage is reported up front) and `NewRequest`;
    /// - inserts or extends assistant messages for text and thinking chunks;
    /// - registers tool uses, emitting `StreamedToolUse` while input is incomplete;
    /// - tracks token usage and queue state;
    /// - routes `BadInputJson` errors to `receive_invalid_tool_json`, and aborts on
    ///   any other stream error.
    ///
    /// When the stream ends it:
    /// - finalizes the pending checkpoint;
    /// - runs pending tools if the stop reason is `ToolUse`;
    /// - on failure, emits the matching `ShowError` (or records
    ///   `exceeded_window_error`) and cancels the last completion;
    /// - emits `Stopped`, invokes the request callback (if one is set) with the
    ///   request and all recorded response events, and reports telemetry.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Keep a copy of the request (and later the response events) only when a
        // callback is registered to receive them.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Captured before streaming so the per-completion usage delta can be
            // reported in telemetry at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        if let Some(usage) = usage {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        }
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // Id of the assistant message that events from this request attach to.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            // Malformed tool input is reported back as a tool error
                            // rather than aborting the whole stream.
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Each update is treated as this request's running
                                // total, so only the delta since the previous
                                // update is added to the cumulative count.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses need an assistant message to attach to;
                                // create an empty one if this request has none yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only partially streamed input is surfaced as a
                                // `StreamedToolUse` event below.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::QueueUpdate(queue_event) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    completion.queue_state = match queue_event {
                                        language_model::QueueState::Queued { position } => {
                                            QueueState::Queued { position }
                                        }
                                        language_model::QueueState::Started => QueueState::Started,
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                        // Outer `?`: the thread entity was dropped. Inner `?`: the
                        // stream reported a non-recoverable error.
                    })??;

                    // Give other tasks a chance to run between stream events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    // Recorded on the thread rather than shown as a
                                    // generic error.
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            // NOTE(review): on error the `retain` above was skipped, so this
                            // completion is still in `pending_completions`; this presumably
                            // pops it (it pops the most recently pushed entry) — confirm.
                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1648
1649 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1650 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1651 return;
1652 };
1653
1654 if !model.provider.is_authenticated(cx) {
1655 return;
1656 }
1657
1658 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1659 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1660 If the conversation is about a specific subject, include it in the title. \
1661 Be descriptive. DO NOT speak in the first person.";
1662
1663 let request = self.to_summarize_request(added_user_message.into());
1664
1665 self.pending_summary = cx.spawn(async move |this, cx| {
1666 async move {
1667 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1668 let (mut messages, usage) = stream.await?;
1669
1670 if let Some(usage) = usage {
1671 this.update(cx, |_thread, cx| {
1672 cx.emit(ThreadEvent::UsageUpdated(usage));
1673 })
1674 .ok();
1675 }
1676
1677 let mut new_summary = String::new();
1678 while let Some(message) = messages.stream.next().await {
1679 let text = message?;
1680 let mut lines = text.lines();
1681 new_summary.extend(lines.next());
1682
1683 // Stop if the LLM generated multiple lines.
1684 if lines.next().is_some() {
1685 break;
1686 }
1687 }
1688
1689 this.update(cx, |this, cx| {
1690 if !new_summary.is_empty() {
1691 this.summary = Some(new_summary.into());
1692 }
1693
1694 cx.emit(ThreadEvent::SummaryGenerated);
1695 })?;
1696
1697 anyhow::Ok(())
1698 }
1699 .log_err()
1700 .await
1701 });
1702 }
1703
1704 pub fn start_generating_detailed_summary_if_needed(
1705 &mut self,
1706 thread_store: WeakEntity<ThreadStore>,
1707 cx: &mut Context<Self>,
1708 ) {
1709 let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
1710 return;
1711 };
1712
1713 match &*self.detailed_summary_rx.borrow() {
1714 DetailedSummaryState::Generating { message_id, .. }
1715 | DetailedSummaryState::Generated { message_id, .. }
1716 if *message_id == last_message_id =>
1717 {
1718 // Already up-to-date
1719 return;
1720 }
1721 _ => {}
1722 }
1723
1724 let Some(ConfiguredModel { model, provider }) =
1725 LanguageModelRegistry::read_global(cx).thread_summary_model()
1726 else {
1727 return;
1728 };
1729
1730 if !provider.is_authenticated(cx) {
1731 return;
1732 }
1733
1734 let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1735 1. A brief overview of what was discussed\n\
1736 2. Key facts or information discovered\n\
1737 3. Outcomes or conclusions reached\n\
1738 4. Any action items or next steps if any\n\
1739 Format it in Markdown with headings and bullet points.";
1740
1741 let request = self.to_summarize_request(added_user_message.into());
1742
1743 *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
1744 message_id: last_message_id,
1745 };
1746
1747 // Replace the detailed summarization task if there is one, cancelling it. It would probably
1748 // be better to allow the old task to complete, but this would require logic for choosing
1749 // which result to prefer (the old task could complete after the new one, resulting in a
1750 // stale summary).
1751 self.detailed_summary_task = cx.spawn(async move |thread, cx| {
1752 let stream = model.stream_completion_text(request, &cx);
1753 let Some(mut messages) = stream.await.log_err() else {
1754 thread
1755 .update(cx, |thread, _cx| {
1756 *thread.detailed_summary_tx.borrow_mut() =
1757 DetailedSummaryState::NotGenerated;
1758 })
1759 .ok()?;
1760 return None;
1761 };
1762
1763 let mut new_detailed_summary = String::new();
1764
1765 while let Some(chunk) = messages.stream.next().await {
1766 if let Some(chunk) = chunk.log_err() {
1767 new_detailed_summary.push_str(&chunk);
1768 }
1769 }
1770
1771 thread
1772 .update(cx, |thread, _cx| {
1773 *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
1774 text: new_detailed_summary.into(),
1775 message_id: last_message_id,
1776 };
1777 })
1778 .ok()?;
1779
1780 // Save thread so its summary can be reused later
1781 if let Some(thread) = thread.upgrade() {
1782 if let Ok(Ok(save_task)) = cx.update(|cx| {
1783 thread_store
1784 .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
1785 }) {
1786 save_task.await.log_err();
1787 }
1788 }
1789
1790 Some(())
1791 });
1792 }
1793
1794 pub async fn wait_for_detailed_summary_or_text(
1795 this: &Entity<Self>,
1796 cx: &mut AsyncApp,
1797 ) -> Option<SharedString> {
1798 let mut detailed_summary_rx = this
1799 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1800 .ok()?;
1801 loop {
1802 match detailed_summary_rx.recv().await? {
1803 DetailedSummaryState::Generating { .. } => {}
1804 DetailedSummaryState::NotGenerated => {
1805 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1806 }
1807 DetailedSummaryState::Generated { text, .. } => return Some(text),
1808 }
1809 }
1810 }
1811
1812 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1813 self.detailed_summary_rx
1814 .borrow()
1815 .text()
1816 .unwrap_or_else(|| self.text().into())
1817 }
1818
1819 pub fn is_generating_detailed_summary(&self) -> bool {
1820 matches!(
1821 &*self.detailed_summary_rx.borrow(),
1822 DetailedSummaryState::Generating { .. }
1823 )
1824 }
1825
1826 pub fn use_pending_tools(
1827 &mut self,
1828 window: Option<AnyWindowHandle>,
1829 cx: &mut Context<Self>,
1830 model: Arc<dyn LanguageModel>,
1831 ) -> Vec<PendingToolUse> {
1832 self.auto_capture_telemetry(cx);
1833 let request = self.to_completion_request(model, cx);
1834 let messages = Arc::new(request.messages);
1835 let pending_tool_uses = self
1836 .tool_use
1837 .pending_tool_uses()
1838 .into_iter()
1839 .filter(|tool_use| tool_use.status.is_idle())
1840 .cloned()
1841 .collect::<Vec<_>>();
1842
1843 for tool_use in pending_tool_uses.iter() {
1844 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1845 if tool.needs_confirmation(&tool_use.input, cx)
1846 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1847 {
1848 self.tool_use.confirm_tool_use(
1849 tool_use.id.clone(),
1850 tool_use.ui_text.clone(),
1851 tool_use.input.clone(),
1852 messages.clone(),
1853 tool,
1854 );
1855 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1856 } else {
1857 self.run_tool(
1858 tool_use.id.clone(),
1859 tool_use.ui_text.clone(),
1860 tool_use.input.clone(),
1861 &messages,
1862 tool,
1863 window,
1864 cx,
1865 );
1866 }
1867 }
1868 }
1869
1870 pending_tool_uses
1871 }
1872
1873 pub fn receive_invalid_tool_json(
1874 &mut self,
1875 tool_use_id: LanguageModelToolUseId,
1876 tool_name: Arc<str>,
1877 invalid_json: Arc<str>,
1878 error: String,
1879 window: Option<AnyWindowHandle>,
1880 cx: &mut Context<Thread>,
1881 ) {
1882 log::error!("The model returned invalid input JSON: {invalid_json}");
1883
1884 let pending_tool_use = self.tool_use.insert_tool_output(
1885 tool_use_id.clone(),
1886 tool_name,
1887 Err(anyhow!("Error parsing input JSON: {error}")),
1888 self.configured_model.as_ref(),
1889 );
1890 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1891 pending_tool_use.ui_text.clone()
1892 } else {
1893 log::error!(
1894 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1895 );
1896 format!("Unknown tool {}", tool_use_id).into()
1897 };
1898
1899 cx.emit(ThreadEvent::InvalidToolInput {
1900 tool_use_id: tool_use_id.clone(),
1901 ui_text,
1902 invalid_input_json: invalid_json,
1903 });
1904
1905 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1906 }
1907
1908 pub fn run_tool(
1909 &mut self,
1910 tool_use_id: LanguageModelToolUseId,
1911 ui_text: impl Into<SharedString>,
1912 input: serde_json::Value,
1913 messages: &[LanguageModelRequestMessage],
1914 tool: Arc<dyn Tool>,
1915 window: Option<AnyWindowHandle>,
1916 cx: &mut Context<Thread>,
1917 ) {
1918 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1919 self.tool_use
1920 .run_pending_tool(tool_use_id, ui_text.into(), task);
1921 }
1922
1923 fn spawn_tool_use(
1924 &mut self,
1925 tool_use_id: LanguageModelToolUseId,
1926 messages: &[LanguageModelRequestMessage],
1927 input: serde_json::Value,
1928 tool: Arc<dyn Tool>,
1929 window: Option<AnyWindowHandle>,
1930 cx: &mut Context<Thread>,
1931 ) -> Task<()> {
1932 let tool_name: Arc<str> = tool.name().into();
1933
1934 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1935 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1936 } else {
1937 tool.run(
1938 input,
1939 messages,
1940 self.project.clone(),
1941 self.action_log.clone(),
1942 window,
1943 cx,
1944 )
1945 };
1946
1947 // Store the card separately if it exists
1948 if let Some(card) = tool_result.card.clone() {
1949 self.tool_use
1950 .insert_tool_result_card(tool_use_id.clone(), card);
1951 }
1952
1953 cx.spawn({
1954 async move |thread: WeakEntity<Thread>, cx| {
1955 let output = tool_result.output.await;
1956
1957 thread
1958 .update(cx, |thread, cx| {
1959 let pending_tool_use = thread.tool_use.insert_tool_output(
1960 tool_use_id.clone(),
1961 tool_name,
1962 output,
1963 thread.configured_model.as_ref(),
1964 );
1965 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1966 })
1967 .ok();
1968 }
1969 })
1970 }
1971
1972 fn tool_finished(
1973 &mut self,
1974 tool_use_id: LanguageModelToolUseId,
1975 pending_tool_use: Option<PendingToolUse>,
1976 canceled: bool,
1977 window: Option<AnyWindowHandle>,
1978 cx: &mut Context<Self>,
1979 ) {
1980 if self.all_tools_finished() {
1981 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
1982 if !canceled {
1983 self.send_to_model(model.clone(), window, cx);
1984 }
1985 self.auto_capture_telemetry(cx);
1986 }
1987 }
1988
1989 cx.emit(ThreadEvent::ToolFinished {
1990 tool_use_id,
1991 pending_tool_use,
1992 });
1993 }
1994
1995 /// Cancels the last pending completion, if there are any pending.
1996 ///
1997 /// Returns whether a completion was canceled.
1998 pub fn cancel_last_completion(
1999 &mut self,
2000 window: Option<AnyWindowHandle>,
2001 cx: &mut Context<Self>,
2002 ) -> bool {
2003 let mut canceled = self.pending_completions.pop().is_some();
2004
2005 for pending_tool_use in self.tool_use.cancel_pending() {
2006 canceled = true;
2007 self.tool_finished(
2008 pending_tool_use.id.clone(),
2009 Some(pending_tool_use),
2010 true,
2011 window,
2012 cx,
2013 );
2014 }
2015
2016 self.finalize_pending_checkpoint(cx);
2017
2018 if canceled {
2019 cx.emit(ThreadEvent::CompletionCanceled);
2020 }
2021
2022 canceled
2023 }
2024
2025 /// Signals that any in-progress editing should be canceled.
2026 ///
2027 /// This method is used to notify listeners (like ActiveThread) that
2028 /// they should cancel any editing operations.
2029 pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
2030 cx.emit(ThreadEvent::CancelEditing);
2031 }
2032
2033 pub fn feedback(&self) -> Option<ThreadFeedback> {
2034 self.feedback
2035 }
2036
2037 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
2038 self.message_feedback.get(&message_id).copied()
2039 }
2040
2041 pub fn report_message_feedback(
2042 &mut self,
2043 message_id: MessageId,
2044 feedback: ThreadFeedback,
2045 cx: &mut Context<Self>,
2046 ) -> Task<Result<()>> {
2047 if self.message_feedback.get(&message_id) == Some(&feedback) {
2048 return Task::ready(Ok(()));
2049 }
2050
2051 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2052 let serialized_thread = self.serialize(cx);
2053 let thread_id = self.id().clone();
2054 let client = self.project.read(cx).client();
2055
2056 let enabled_tool_names: Vec<String> = self
2057 .tools()
2058 .read(cx)
2059 .enabled_tools(cx)
2060 .iter()
2061 .map(|tool| tool.name().to_string())
2062 .collect();
2063
2064 self.message_feedback.insert(message_id, feedback);
2065
2066 cx.notify();
2067
2068 let message_content = self
2069 .message(message_id)
2070 .map(|msg| msg.to_string())
2071 .unwrap_or_default();
2072
2073 cx.background_spawn(async move {
2074 let final_project_snapshot = final_project_snapshot.await;
2075 let serialized_thread = serialized_thread.await?;
2076 let thread_data =
2077 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2078
2079 let rating = match feedback {
2080 ThreadFeedback::Positive => "positive",
2081 ThreadFeedback::Negative => "negative",
2082 };
2083 telemetry::event!(
2084 "Assistant Thread Rated",
2085 rating,
2086 thread_id,
2087 enabled_tool_names,
2088 message_id = message_id.0,
2089 message_content,
2090 thread_data,
2091 final_project_snapshot
2092 );
2093 client.telemetry().flush_events().await;
2094
2095 Ok(())
2096 })
2097 }
2098
2099 pub fn report_feedback(
2100 &mut self,
2101 feedback: ThreadFeedback,
2102 cx: &mut Context<Self>,
2103 ) -> Task<Result<()>> {
2104 let last_assistant_message_id = self
2105 .messages
2106 .iter()
2107 .rev()
2108 .find(|msg| msg.role == Role::Assistant)
2109 .map(|msg| msg.id);
2110
2111 if let Some(message_id) = last_assistant_message_id {
2112 self.report_message_feedback(message_id, feedback, cx)
2113 } else {
2114 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2115 let serialized_thread = self.serialize(cx);
2116 let thread_id = self.id().clone();
2117 let client = self.project.read(cx).client();
2118 self.feedback = Some(feedback);
2119 cx.notify();
2120
2121 cx.background_spawn(async move {
2122 let final_project_snapshot = final_project_snapshot.await;
2123 let serialized_thread = serialized_thread.await?;
2124 let thread_data = serde_json::to_value(serialized_thread)
2125 .unwrap_or_else(|_| serde_json::Value::Null);
2126
2127 let rating = match feedback {
2128 ThreadFeedback::Positive => "positive",
2129 ThreadFeedback::Negative => "negative",
2130 };
2131 telemetry::event!(
2132 "Assistant Thread Rated",
2133 rating,
2134 thread_id,
2135 thread_data,
2136 final_project_snapshot
2137 );
2138 client.telemetry().flush_events().await;
2139
2140 Ok(())
2141 })
2142 }
2143 }
2144
2145 /// Create a snapshot of the current project state including git information and unsaved buffers.
2146 fn project_snapshot(
2147 project: Entity<Project>,
2148 cx: &mut Context<Self>,
2149 ) -> Task<Arc<ProjectSnapshot>> {
2150 let git_store = project.read(cx).git_store().clone();
2151 let worktree_snapshots: Vec<_> = project
2152 .read(cx)
2153 .visible_worktrees(cx)
2154 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2155 .collect();
2156
2157 cx.spawn(async move |_, cx| {
2158 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2159
2160 let mut unsaved_buffers = Vec::new();
2161 cx.update(|app_cx| {
2162 let buffer_store = project.read(app_cx).buffer_store();
2163 for buffer_handle in buffer_store.read(app_cx).buffers() {
2164 let buffer = buffer_handle.read(app_cx);
2165 if buffer.is_dirty() {
2166 if let Some(file) = buffer.file() {
2167 let path = file.path().to_string_lossy().to_string();
2168 unsaved_buffers.push(path);
2169 }
2170 }
2171 }
2172 })
2173 .ok();
2174
2175 Arc::new(ProjectSnapshot {
2176 worktree_snapshots,
2177 unsaved_buffer_paths: unsaved_buffers,
2178 timestamp: Utc::now(),
2179 })
2180 })
2181 }
2182
2183 fn worktree_snapshot(
2184 worktree: Entity<project::Worktree>,
2185 git_store: Entity<GitStore>,
2186 cx: &App,
2187 ) -> Task<WorktreeSnapshot> {
2188 cx.spawn(async move |cx| {
2189 // Get worktree path and snapshot
2190 let worktree_info = cx.update(|app_cx| {
2191 let worktree = worktree.read(app_cx);
2192 let path = worktree.abs_path().to_string_lossy().to_string();
2193 let snapshot = worktree.snapshot();
2194 (path, snapshot)
2195 });
2196
2197 let Ok((worktree_path, _snapshot)) = worktree_info else {
2198 return WorktreeSnapshot {
2199 worktree_path: String::new(),
2200 git_state: None,
2201 };
2202 };
2203
2204 let git_state = git_store
2205 .update(cx, |git_store, cx| {
2206 git_store
2207 .repositories()
2208 .values()
2209 .find(|repo| {
2210 repo.read(cx)
2211 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2212 .is_some()
2213 })
2214 .cloned()
2215 })
2216 .ok()
2217 .flatten()
2218 .map(|repo| {
2219 repo.update(cx, |repo, _| {
2220 let current_branch =
2221 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2222 repo.send_job(None, |state, _| async move {
2223 let RepositoryState::Local { backend, .. } = state else {
2224 return GitState {
2225 remote_url: None,
2226 head_sha: None,
2227 current_branch,
2228 diff: None,
2229 };
2230 };
2231
2232 let remote_url = backend.remote_url("origin");
2233 let head_sha = backend.head_sha().await;
2234 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2235
2236 GitState {
2237 remote_url,
2238 head_sha,
2239 current_branch,
2240 diff,
2241 }
2242 })
2243 })
2244 });
2245
2246 let git_state = match git_state {
2247 Some(git_state) => match git_state.ok() {
2248 Some(git_state) => git_state.await.ok(),
2249 None => None,
2250 },
2251 None => None,
2252 };
2253
2254 WorktreeSnapshot {
2255 worktree_path,
2256 git_state,
2257 }
2258 })
2259 }
2260
2261 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2262 let mut markdown = Vec::new();
2263
2264 if let Some(summary) = self.summary() {
2265 writeln!(markdown, "# {summary}\n")?;
2266 };
2267
2268 for message in self.messages() {
2269 writeln!(
2270 markdown,
2271 "## {role}\n",
2272 role = match message.role {
2273 Role::User => "User",
2274 Role::Assistant => "Assistant",
2275 Role::System => "System",
2276 }
2277 )?;
2278
2279 if !message.loaded_context.text.is_empty() {
2280 writeln!(markdown, "{}", message.loaded_context.text)?;
2281 }
2282
2283 if !message.loaded_context.images.is_empty() {
2284 writeln!(
2285 markdown,
2286 "\n{} images attached as context.\n",
2287 message.loaded_context.images.len()
2288 )?;
2289 }
2290
2291 for segment in &message.segments {
2292 match segment {
2293 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2294 MessageSegment::Thinking { text, .. } => {
2295 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2296 }
2297 MessageSegment::RedactedThinking(_) => {}
2298 }
2299 }
2300
2301 for tool_use in self.tool_uses_for_message(message.id, cx) {
2302 writeln!(
2303 markdown,
2304 "**Use Tool: {} ({})**",
2305 tool_use.name, tool_use.id
2306 )?;
2307 writeln!(markdown, "```json")?;
2308 writeln!(
2309 markdown,
2310 "{}",
2311 serde_json::to_string_pretty(&tool_use.input)?
2312 )?;
2313 writeln!(markdown, "```")?;
2314 }
2315
2316 for tool_result in self.tool_results_for_message(message.id) {
2317 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2318 if tool_result.is_error {
2319 write!(markdown, " (Error)")?;
2320 }
2321
2322 writeln!(markdown, "**\n")?;
2323 writeln!(markdown, "{}", tool_result.content)?;
2324 }
2325 }
2326
2327 Ok(String::from_utf8_lossy(&markdown).to_string())
2328 }
2329
2330 pub fn keep_edits_in_range(
2331 &mut self,
2332 buffer: Entity<language::Buffer>,
2333 buffer_range: Range<language::Anchor>,
2334 cx: &mut Context<Self>,
2335 ) {
2336 self.action_log.update(cx, |action_log, cx| {
2337 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2338 });
2339 }
2340
2341 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2342 self.action_log
2343 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2344 }
2345
2346 pub fn reject_edits_in_ranges(
2347 &mut self,
2348 buffer: Entity<language::Buffer>,
2349 buffer_ranges: Vec<Range<language::Anchor>>,
2350 cx: &mut Context<Self>,
2351 ) -> Task<Result<()>> {
2352 self.action_log.update(cx, |action_log, cx| {
2353 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2354 })
2355 }
2356
2357 pub fn action_log(&self) -> &Entity<ActionLog> {
2358 &self.action_log
2359 }
2360
2361 pub fn project(&self) -> &Entity<Project> {
2362 &self.project
2363 }
2364
2365 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2366 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2367 return;
2368 }
2369
2370 let now = Instant::now();
2371 if let Some(last) = self.last_auto_capture_at {
2372 if now.duration_since(last).as_secs() < 10 {
2373 return;
2374 }
2375 }
2376
2377 self.last_auto_capture_at = Some(now);
2378
2379 let thread_id = self.id().clone();
2380 let github_login = self
2381 .project
2382 .read(cx)
2383 .user_store()
2384 .read(cx)
2385 .current_user()
2386 .map(|user| user.github_login.clone());
2387 let client = self.project.read(cx).client().clone();
2388 let serialize_task = self.serialize(cx);
2389
2390 cx.background_executor()
2391 .spawn(async move {
2392 if let Ok(serialized_thread) = serialize_task.await {
2393 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2394 telemetry::event!(
2395 "Agent Thread Auto-Captured",
2396 thread_id = thread_id.to_string(),
2397 thread_data = thread_data,
2398 auto_capture_reason = "tracked_user",
2399 github_login = github_login
2400 );
2401
2402 client.telemetry().flush_events().await;
2403 }
2404 }
2405 })
2406 .detach();
2407 }
2408
2409 pub fn cumulative_token_usage(&self) -> TokenUsage {
2410 self.cumulative_token_usage
2411 }
2412
2413 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2414 let Some(model) = self.configured_model.as_ref() else {
2415 return TotalTokenUsage::default();
2416 };
2417
2418 let max = model.model.max_token_count();
2419
2420 let index = self
2421 .messages
2422 .iter()
2423 .position(|msg| msg.id == message_id)
2424 .unwrap_or(0);
2425
2426 if index == 0 {
2427 return TotalTokenUsage { total: 0, max };
2428 }
2429
2430 let token_usage = &self
2431 .request_token_usage
2432 .get(index - 1)
2433 .cloned()
2434 .unwrap_or_default();
2435
2436 TotalTokenUsage {
2437 total: token_usage.total_tokens() as usize,
2438 max,
2439 }
2440 }
2441
2442 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2443 let model = self.configured_model.as_ref()?;
2444
2445 let max = model.model.max_token_count();
2446
2447 if let Some(exceeded_error) = &self.exceeded_window_error {
2448 if model.model.id() == exceeded_error.model_id {
2449 return Some(TotalTokenUsage {
2450 total: exceeded_error.token_count,
2451 max,
2452 });
2453 }
2454 }
2455
2456 let total = self
2457 .token_usage_at_last_message()
2458 .unwrap_or_default()
2459 .total_tokens() as usize;
2460
2461 Some(TotalTokenUsage { total, max })
2462 }
2463
2464 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2465 self.request_token_usage
2466 .get(self.messages.len().saturating_sub(1))
2467 .or_else(|| self.request_token_usage.last())
2468 .cloned()
2469 }
2470
2471 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2472 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2473 self.request_token_usage
2474 .resize(self.messages.len(), placeholder);
2475
2476 if let Some(last) = self.request_token_usage.last_mut() {
2477 *last = token_usage;
2478 }
2479 }
2480
2481 pub fn deny_tool_use(
2482 &mut self,
2483 tool_use_id: LanguageModelToolUseId,
2484 tool_name: Arc<str>,
2485 window: Option<AnyWindowHandle>,
2486 cx: &mut Context<Self>,
2487 ) {
2488 let err = Err(anyhow::anyhow!(
2489 "Permission to run tool action denied by user"
2490 ));
2491
2492 self.tool_use.insert_tool_output(
2493 tool_use_id.clone(),
2494 tool_name,
2495 err,
2496 self.configured_model.as_ref(),
2497 );
2498 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2499 }
2500}
2501
/// Errors a thread reports to its observers, e.g. through
/// `ThreadEvent::ShowError`.
///
/// The `#[error(...)]` strings are each variant's `Display` text.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    #[error("Payment required")]
    PaymentRequired,
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The model request limit was reached; carries the user's `plan`.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Generic error with a `header` and a `message`, displayed as
    /// `Message {header}: {message}`.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2516
/// Events a [`Thread`] emits to its observers via gpui's `EventEmitter`.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error to surface to the user.
    ShowError(ThreadError),
    /// New request-usage information.
    UsageUpdated(RequestUsage),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// Assistant text streamed for the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant thinking text streamed for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use streamed from the model, with its display text and JSON input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// A tool use whose input was invalid; the raw text is kept in
    /// `invalid_input_json`.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// A completion stopped, either with a [`StopReason`] or with an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    /// Tool uses that are pending and ready to be run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        // NOTE(review): `allow(unused)` suggests no consumer reads this field
        // (it still shows up in the derived `Debug` output). Confirm whether it
        // can be used or removed.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    CancelEditing,
    CompletionCanceled,
}
2556
// Lets `Thread` entities emit `ThreadEvent`s to subscribers.
impl EventEmitter<ThreadEvent> for Thread {}

/// Bookkeeping for a completion request the thread has started.
struct PendingCompletion {
    /// Identifier of this completion.
    id: usize,
    queue_state: QueueState,
    // NOTE(review): never read; presumably held only so the completion task
    // stays alive as long as this struct does. Confirm against gpui `Task`
    // drop semantics.
    _task: Task<()>,
}
2564
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use assistant_tool::ToolRegistry;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use language_model::fake_provider::FakeLanguageModel;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// A user message inserted with file context keeps its text and its loaded
    /// context separate on the `Message`. The completion request prepends the
    /// context text to the message text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // `{{` / `}}` are `format!` escapes for literal braces. The newline
        // right after `r#"` and the one before `"#` are part of the expected text.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Index 0 is a leading message the thread adds itself (presumably the
        // system prompt — confirm). The user message is at index 1.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    /// Each new user message only carries context that earlier messages have
    /// not already included. Passing `Some(message_id)` to
    /// `new_context_for_thread` computes new context relative to that message
    /// instead of the end of the thread.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request holds the 3 user messages (indices 1..=3) after the
        // leading message at index 0.
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Relative to message 2, only context included before it (file1)
        // counts as already sent. file2, file3 and the new file4 are all "new".
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }

    /// Messages inserted with an empty `ContextLoadResult` carry no context
    /// text. Their request contents are exactly the message text.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    /// After a buffer that was sent as context is modified, the next
    /// completion request ends with a user message listing the changed file.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert at offset 1 (inside the first line). The exact position
            // doesn't matter; the test only needs the buffer to change.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    /// Installs a test `SettingsStore` and registers every settings type and
    /// global the thread, workspace and tool code touch in these tests.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }

    /// Creates a test project backed by a `FakeFs`. `files` (a JSON tree of
    /// name -> contents) is inserted under the `/test` root.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Builds the objects a thread test needs.
    ///
    /// Returns, in order:
    /// - a workspace window for `project`;
    /// - a loaded `ThreadStore`;
    /// - a fresh thread;
    /// - an empty `ContextStore`;
    /// - a `FakeLanguageModel`.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }

    /// Opens the buffer for `path` and adds it to `context_store` as file
    /// context.
    ///
    /// Returns the opened buffer. Failures panic via `unwrap`, so the returned
    /// `Result` is always `Ok`.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
}