1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use util::{ResultExt as _, TryFutureExt as _, post_inc};
39use uuid::Uuid;
40use zed_llm_client::CompletionMode;
41
42use crate::ThreadStore;
43use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
44use crate::thread_store::{
45 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
46 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
47};
48use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
49
/// Globally unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// Backed by a UUID string; a new one is generated via [`Thread::advance_prompt_id`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
78
79impl PromptId {
80 pub fn new() -> Self {
81 Self(Uuid::new_v4().to_string().into())
82 }
83}
84
85impl std::fmt::Display for PromptId {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 write!(f, "{}", self.0)
88 }
89}
90
/// Monotonically increasing identifier of a [`Message`] within a thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
93
94impl MessageId {
95 fn post_inc(&mut self) -> Self {
96 Self(post_inc(&mut self.0))
97 }
98}
99
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message text that the crease covers.
    pub range: Range<usize>,
    /// Display metadata (icon and label) for re-rendering the crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
108
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text / thinking segments that make up the message body.
    pub segments: Vec<MessageSegment>,
    /// Context (text and images) that was loaded for this message.
    pub loaded_context: LoadedContext,
    /// Creases used to re-create context pills in editors for this message.
    pub creases: Vec<MessageCrease>,
}
118
119impl Message {
120 /// Returns whether the message contains any meaningful text that should be displayed
121 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
122 pub fn should_display_content(&self) -> bool {
123 self.segments.iter().all(|segment| segment.should_display())
124 }
125
126 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
127 if let Some(MessageSegment::Thinking {
128 text: segment,
129 signature: current_signature,
130 }) = self.segments.last_mut()
131 {
132 if let Some(signature) = signature {
133 *current_signature = Some(signature);
134 }
135 segment.push_str(text);
136 } else {
137 self.segments.push(MessageSegment::Thinking {
138 text: text.to_string(),
139 signature,
140 });
141 }
142 }
143
144 pub fn push_text(&mut self, text: &str) {
145 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
146 segment.push_str(text);
147 } else {
148 self.segments.push(MessageSegment::Text(text.to_string()));
149 }
150 }
151
152 pub fn to_string(&self) -> String {
153 let mut result = String::new();
154
155 if !self.loaded_context.text.is_empty() {
156 result.push_str(&self.loaded_context.text);
157 }
158
159 for segment in &self.segments {
160 match segment {
161 MessageSegment::Text(text) => result.push_str(text),
162 MessageSegment::Thinking { text, .. } => {
163 result.push_str("<think>\n");
164 result.push_str(text);
165 result.push_str("\n</think>");
166 }
167 MessageSegment::RedactedThinking(_) => {}
168 }
169 }
170
171 result
172 }
173}
174
/// One contiguous piece of a [`Message`]: plain text, visible model
/// "thinking", or redacted (opaque) thinking data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment carries user-visible content.
    ///
    /// Text and thinking segments are displayable only when non-empty; the
    /// previous implementation had the check inverted (`is_empty()` without
    /// negation), which reported empty segments as displayable and real text
    /// as hidden. Redacted thinking is never displayed.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
194
/// Snapshot of the project's worktrees and unsaved state, captured when a
/// thread is created and stored with the serialized thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved edits at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree: its path plus git state, if it is a repo.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Textual diff of uncommitted changes, when available.
    pub diff: Option<String>,
}
215
/// Associates a git checkpoint with the message it was taken before, so the
/// project can be restored to the state it had at that message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up / thumbs down) on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
227
/// Tracks the outcome of the most recent checkpoint restore attempt:
/// either still in flight, or failed with an error to surface in the UI.
/// (A successful restore clears this state entirely.)
pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}
237
238impl LastRestoreCheckpoint {
239 pub fn message_id(&self) -> MessageId {
240 match self {
241 LastRestoreCheckpoint::Pending { message_id } => *message_id,
242 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
243 }
244 }
245}
246
/// Lifecycle of the thread's detailed (long-form) summary: not yet requested,
/// generating up to a given message, or generated for a given message.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
259
260impl DetailedSummaryState {
261 fn text(&self) -> Option<SharedString> {
262 if let Self::Generated { text, .. } = self {
263 Some(text.clone())
264 } else {
265 None
266 }
267 }
268}
269
/// Running token usage for a thread, relative to the model's context window.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// The model's context-window size; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, nearing the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via an
        // environment variable. Missing or malformed values fall back to the
        // default instead of panicking (the previous `.parse().unwrap()`
        // would crash on e.g. ZED_THREAD_WARNING_THRESHOLD="abc").
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close the thread is to exhausting the model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
314
315fn default_completion_mode(cx: &App) -> CompletionMode {
316 if cx.is_staff() {
317 CompletionMode::Max
318 } else {
319 CompletionMode::Normal
320 }
321}
322
/// Where a pending completion currently sits: being sent, waiting in the
/// provider's queue at a given position, or actively streaming.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
329
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// `None` until a summary has been generated.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: CompletionMode,
    /// Messages in order of ascending `MessageId`.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Outcome of the most recent checkpoint restore attempt, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken for the latest user message, awaiting finalization
    /// once generation completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured when the thread was created; stored with the
    /// serialized thread.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    /// Set when the last request exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Optional hook invoked with each completion request and the response
    /// events received for it.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Remaining agent turns; decremented by `send_to_model`.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
367
/// Records that a request exceeded the model's context window, so the UI can
/// explain why the last message failed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
375
376impl Thread {
    /// Creates a new, empty thread.
    ///
    /// The configured model starts as the global [`LanguageModelRegistry`]
    /// default, and a project snapshot is captured asynchronously so it can
    /// later be stored with the serialized thread.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project's current state in the background; the
                // result is shared so `serialize` can await it later.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
427
    /// Reconstructs a thread from its serialized representation.
    ///
    /// Restores messages (with tool-use state rebuilt from them), token usage,
    /// and the detailed-summary state. The serialized model is re-selected
    /// from the registry when possible; otherwise the global default is used.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest restored message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: default_completion_mode(cx),
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text is persisted; the
                    // structured contexts and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Crease context handles are not persisted.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
533
    /// Installs a callback that is invoked with each completion request and
    /// the stream of response events received for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
541
    /// The unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread contains no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Replaces `last_prompt_id` with a freshly generated prompt ID.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if one hasn't been generated yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// The shared project context used when building the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
569
    /// Returns the configured model, falling back to (and caching) the global
    /// default model when none has been configured yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The model currently configured for this thread, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the configured model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
585
586 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
587
588 pub fn summary_or_default(&self) -> SharedString {
589 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
590 }
591
592 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
593 let Some(current_summary) = &self.summary else {
594 // Don't allow setting summary until generated
595 return;
596 };
597
598 let mut new_summary = new_summary.into();
599
600 if new_summary.is_empty() {
601 new_summary = Self::DEFAULT_SUMMARY;
602 }
603
604 if current_summary != &new_summary {
605 self.summary = Some(new_summary);
606 cx.emit(ThreadEvent::SummaryChanged);
607 }
608 }
609
    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the completion mode used for requests from this thread.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
617
618 pub fn message(&self, id: MessageId) -> Option<&Message> {
619 let index = self
620 .messages
621 .binary_search_by(|message| message.id.cmp(&id))
622 .ok()?;
623
624 self.messages.get(index)
625 }
626
    /// Iterates over the thread's messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
644
    /// Looks up a pending tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Iterates over pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
662
    /// The checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
666
    /// Restores the project's files to the git state captured in `checkpoint`.
    ///
    /// While the restore is in flight the state is tracked as
    /// [`LastRestoreCheckpoint::Pending`]. On success the thread is truncated
    /// back to the checkpoint's message; on failure the error is recorded so
    /// the UI can surface it.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failed state around so the UI can show it.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop the checkpoint's message and everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
702
    /// Decides whether the pending checkpoint should be kept once generation
    /// has finished.
    ///
    /// The checkpoint is inserted only if the repository state changed while
    /// generating; if the comparison itself fails we keep the checkpoint to
    /// err on the side of preserving the restore point.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Still generating — leave the checkpoint pending for now.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep the checkpoint if something actually changed.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a final checkpoint — keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
741
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Outcome of the most recent checkpoint restore attempt, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
752
753 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
754 let Some(message_ix) = self
755 .messages
756 .iter()
757 .rposition(|message| message.id == message_id)
758 else {
759 return;
760 };
761 for deleted_message in self.messages.drain(message_ix..) {
762 self.checkpoints_by_message.remove(&deleted_message.id);
763 }
764 cx.notify();
765 }
766
767 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
768 self.messages
769 .iter()
770 .find(|message| message.id == id)
771 .into_iter()
772 .flat_map(|message| message.loaded_context.contexts.iter())
773 }
774
775 pub fn is_turn_end(&self, ix: usize) -> bool {
776 if self.messages.is_empty() {
777 return false;
778 }
779
780 if !self.is_generating() && ix == self.messages.len() - 1 {
781 return true;
782 }
783
784 let Some(message) = self.messages.get(ix) else {
785 return false;
786 };
787
788 if message.role != Role::Assistant {
789 return false;
790 }
791
792 self.messages
793 .get(ix + 1)
794 .and_then(|message| {
795 self.message(message.id)
796 .map(|next_message| next_message.role == Role::User)
797 })
798 .unwrap_or(false)
799 }
800
    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }

    /// Tool uses made by the assistant message with the given ID.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced for the assistant message with the given ID.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The textual output of a completed tool use, if any.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a completed tool use, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
833
834 /// Return tools that are both enabled and supported by the model
835 pub fn available_tools(
836 &self,
837 cx: &App,
838 model: Arc<dyn LanguageModel>,
839 ) -> Vec<LanguageModelRequestTool> {
840 if model.supports_tools() {
841 self.tools()
842 .read(cx)
843 .enabled_tools(cx)
844 .into_iter()
845 .filter_map(|tool| {
846 // Skip tools that cannot be supported
847 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
848 Some(LanguageModelRequestTool {
849 name: tool.name(),
850 description: tool.description(),
851 input_schema,
852 })
853 })
854 .collect()
855 } else {
856 Vec::default()
857 }
858 }
859
    /// Appends a user message to the thread.
    ///
    /// Buffers referenced by the loaded context are registered with the action
    /// log, and `git_checkpoint` (when provided) is stored as the pending
    /// checkpoint associated with the new message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            // Track referenced buffers so edits to them can be attributed.
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.track_buffer(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
895
    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }

    /// Appends a message with the next sequential ID, bumps `updated_at`, and
    /// emits [`ThreadEvent::MessageAdded`]. Returns the new message's ID.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
930
931 pub fn edit_message(
932 &mut self,
933 id: MessageId,
934 new_role: Role,
935 new_segments: Vec<MessageSegment>,
936 loaded_context: Option<LoadedContext>,
937 cx: &mut Context<Self>,
938 ) -> bool {
939 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
940 return false;
941 };
942 message.role = new_role;
943 message.segments = new_segments;
944 if let Some(context) = loaded_context {
945 message.loaded_context = context;
946 }
947 self.touch_updated_at();
948 cx.emit(ThreadEvent::MessageEdited(id));
949 true
950 }
951
952 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
953 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
954 return false;
955 };
956 self.messages.remove(index);
957 self.touch_updated_at();
958 cx.emit(ThreadEvent::MessageDeleted(id));
959 true
960 }
961
962 /// Returns the representation of this [`Thread`] in a textual form.
963 ///
964 /// This is the representation we use when attaching a thread as context to another thread.
965 pub fn text(&self) -> String {
966 let mut text = String::new();
967
968 for message in &self.messages {
969 text.push_str(match message.role {
970 language_model::Role::User => "User:",
971 language_model::Role::Assistant => "Assistant:",
972 language_model::Role::System => "System:",
973 });
974 text.push('\n');
975
976 for segment in &message.segments {
977 match segment {
978 MessageSegment::Text(content) => text.push_str(content),
979 MessageSegment::Thinking { text: content, .. } => {
980 text.push_str(&format!("<think>{}</think>", content))
981 }
982 MessageSegment::RedactedThinking(_) => {}
983 }
984 }
985 text.push('\n');
986 }
987
988 text
989 }
990
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Awaits the initial project snapshot, then captures the thread's
    /// messages (including tool uses/results and creases), token usage,
    /// detailed-summary state, and the currently configured model.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
            })
        })
    }
1071
    /// How many more turns the agent may take before stopping.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    /// Sets how many more turns the agent may take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }

    /// Builds a completion request from the thread's current state and streams
    /// it to `model`, consuming one turn. Does nothing when no turns remain.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), cx);

        self.stream_completion(request, model, window, cx);
    }
1096
1097 pub fn used_tools_since_last_user_message(&self) -> bool {
1098 for message in self.messages.iter().rev() {
1099 if self.tool_use.message_has_tool_results(message.id) {
1100 return true;
1101 } else if message.role == Role::User {
1102 return false;
1103 }
1104 }
1105
1106 false
1107 }
1108
    /// Assembles the full `LanguageModelRequest` for this thread: system
    /// prompt, every message (with its loaded context, segments, tool uses,
    /// and tool results), stale-file notices, available tools, and the
    /// completion mode supported by `model`.
    ///
    /// Emits a `ThreadEvent::ShowError` (and still returns a request) if the
    /// system prompt cannot be generated.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        // The system prompt advertises which tools the model may call, so the
        // tool list is resolved before prompt generation.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt is stable across turns, so mark it
                    // cacheable for providers that support prompt caching.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            // Project context is expected to be populated before a request is
            // built; surface the inconsistency rather than failing silently.
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) precedes the message's
            // own segments within the request message.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results are sent as a separate message immediately after
            // the message whose tool uses produced them.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // Mark the final message as a cache anchor so the prefix up to here
        // can be reused by the provider.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode)
        } else {
            Some(CompletionMode::Normal)
        };

        request
    }
1226
1227 fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
1228 let mut request = LanguageModelRequest {
1229 thread_id: None,
1230 prompt_id: None,
1231 mode: None,
1232 messages: vec![],
1233 tools: Vec::new(),
1234 stop: Vec::new(),
1235 temperature: None,
1236 };
1237
1238 for message in &self.messages {
1239 let mut request_message = LanguageModelRequestMessage {
1240 role: message.role,
1241 content: Vec::new(),
1242 cache: false,
1243 };
1244
1245 for segment in &message.segments {
1246 match segment {
1247 MessageSegment::Text(text) => request_message
1248 .content
1249 .push(MessageContent::Text(text.clone())),
1250 MessageSegment::Thinking { .. } => {}
1251 MessageSegment::RedactedThinking(_) => {}
1252 }
1253 }
1254
1255 if request_message.content.is_empty() {
1256 continue;
1257 }
1258
1259 request.messages.push(request_message);
1260 }
1261
1262 request.messages.push(LanguageModelRequestMessage {
1263 role: Role::User,
1264 content: vec![MessageContent::Text(added_user_message)],
1265 cache: false,
1266 });
1267
1268 request
1269 }
1270
1271 fn attached_tracked_files_state(
1272 &self,
1273 messages: &mut Vec<LanguageModelRequestMessage>,
1274 cx: &App,
1275 ) {
1276 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1277
1278 let mut stale_message = String::new();
1279
1280 let action_log = self.action_log.read(cx);
1281
1282 for stale_file in action_log.stale_buffers(cx) {
1283 let Some(file) = stale_file.read(cx).file() else {
1284 continue;
1285 };
1286
1287 if stale_message.is_empty() {
1288 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1289 }
1290
1291 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1292 }
1293
1294 let mut content = Vec::with_capacity(2);
1295
1296 if !stale_message.is_empty() {
1297 content.push(stale_message.into());
1298 }
1299
1300 if !content.is_empty() {
1301 let context_message = LanguageModelRequestMessage {
1302 role: Role::User,
1303 content,
1304 cache: false,
1305 };
1306
1307 messages.push(context_message);
1308 }
1309 }
1310
    /// Spawns a task that streams a completion for `request` from `model`,
    /// applying each streamed event to the thread (text/thinking chunks, tool
    /// uses, token usage, queue updates) and handling the final stop reason or
    /// error. Registers the work as a `PendingCompletion` so it can be
    /// cancelled.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        // Request/response pairs are only recorded when a callback is
        // installed (used for capturing traffic, e.g. in tests).
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Snapshot usage before streaming so the per-request delta can be
            // reported to telemetry at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                // Id of the assistant message that streamed content is being
                // appended to, once one exists for this response.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            // Malformed tool-input JSON is recoverable: record
                            // it as a failed tool use and keep streaming.
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // only the delta from the previous update is
                                // added to the thread-wide total.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses must hang off an assistant
                                // message; create an empty one if none exists.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Partial input is only surfaced as a streamed
                                // event; complete input goes through
                                // `request_tool_use` alone.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::QueueUpdate(queue_event) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    completion.queue_state = match queue_event {
                                        language_model::QueueState::Queued { position } => {
                                            QueueState::Queued { position }
                                        }
                                        language_model::QueueState::Started => QueueState::Started,
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so other tasks on the executor can
                    // make progress during a long stream.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Classify known billing/limit errors into
                            // dedicated UI events; everything else becomes a
                            // generic error message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report the token-usage delta for this request only.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1624
    /// Kicks off generation of a short (3-7 word) thread title using the
    /// configured thread-summary model, storing the result in `self.summary`
    /// and emitting `SummaryGenerated`. No-ops when no summary model is
    /// configured or its provider is not authenticated.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(added_user_message.into());

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let stream = model.model.stream_completion_text_with_usage(request, &cx);
                let (mut messages, usage) = stream.await?;

                if let Some(usage) = usage {
                    this.update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::UsageUpdated(usage));
                    })
                    .ok();
                }

                // Accumulate only the first line of output; a title should be
                // a single line.
                let mut new_summary = String::new();
                while let Some(message) = messages.stream.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    // Keep any existing summary if the model produced nothing.
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1679
    /// Starts generating a detailed Markdown summary of the thread unless one
    /// is already generated (or generating) for the current last message.
    /// Progress and the result are published through the
    /// `detailed_summary_tx`/`detailed_summary_rx` watch channel, and the
    /// thread is saved via `thread_store` once the summary lands.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(added_user_message.into());

        // Publish the Generating state before spawning so observers see it
        // immediately.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Stream setup failed: roll the channel back to NotGenerated
                // so a later attempt can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1769
1770 pub async fn wait_for_detailed_summary_or_text(
1771 this: &Entity<Self>,
1772 cx: &mut AsyncApp,
1773 ) -> Option<SharedString> {
1774 let mut detailed_summary_rx = this
1775 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1776 .ok()?;
1777 loop {
1778 match detailed_summary_rx.recv().await? {
1779 DetailedSummaryState::Generating { .. } => {}
1780 DetailedSummaryState::NotGenerated => {
1781 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1782 }
1783 DetailedSummaryState::Generated { text, .. } => return Some(text),
1784 }
1785 }
1786 }
1787
1788 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1789 self.detailed_summary_rx
1790 .borrow()
1791 .text()
1792 .unwrap_or_else(|| self.text().into())
1793 }
1794
    /// Reports whether a detailed-summary generation task is currently in
    /// flight (the watch channel holds the `Generating` state).
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1801
    /// Dispatches every idle pending tool use: tools that require
    /// confirmation (and are not globally always-allowed) are parked awaiting
    /// user approval, the rest are run immediately. Returns the set of tool
    /// uses that were dispatched.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Tools receive the full conversation as context; build it once and
        // share it between all dispatched tools.
        let request = self.to_completion_request(model, cx);
        let messages = Arc::new(request.messages);
        // Clone the idle tool uses up front so `self` can be borrowed mutably
        // inside the loop below.
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            // Tool uses whose named tool is not registered are left untouched.
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        window,
                        cx,
                    );
                }
            }
        }

        pending_tool_uses
    }
1848
1849 pub fn receive_invalid_tool_json(
1850 &mut self,
1851 tool_use_id: LanguageModelToolUseId,
1852 tool_name: Arc<str>,
1853 invalid_json: Arc<str>,
1854 error: String,
1855 window: Option<AnyWindowHandle>,
1856 cx: &mut Context<Thread>,
1857 ) {
1858 log::error!("The model returned invalid input JSON: {invalid_json}");
1859
1860 let pending_tool_use = self.tool_use.insert_tool_output(
1861 tool_use_id.clone(),
1862 tool_name,
1863 Err(anyhow!("Error parsing input JSON: {error}")),
1864 self.configured_model.as_ref(),
1865 );
1866 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1867 pending_tool_use.ui_text.clone()
1868 } else {
1869 log::error!(
1870 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1871 );
1872 format!("Unknown tool {}", tool_use_id).into()
1873 };
1874
1875 cx.emit(ThreadEvent::InvalidToolInput {
1876 tool_use_id: tool_use_id.clone(),
1877 ui_text,
1878 invalid_input_json: invalid_json,
1879 });
1880
1881 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1882 }
1883
1884 pub fn run_tool(
1885 &mut self,
1886 tool_use_id: LanguageModelToolUseId,
1887 ui_text: impl Into<SharedString>,
1888 input: serde_json::Value,
1889 messages: &[LanguageModelRequestMessage],
1890 tool: Arc<dyn Tool>,
1891 window: Option<AnyWindowHandle>,
1892 cx: &mut Context<Thread>,
1893 ) {
1894 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
1895 self.tool_use
1896 .run_pending_tool(tool_use_id, ui_text.into(), task);
1897 }
1898
    /// Runs `tool` (unless it is disabled, in which case an error result is
    /// produced immediately), stores any UI card the tool returns, and spawns
    /// a task that records the tool's output on the thread when it completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Disabled tools short-circuit to an error result without running.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool ran; in
                // that case the output is discarded.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
1947
1948 fn tool_finished(
1949 &mut self,
1950 tool_use_id: LanguageModelToolUseId,
1951 pending_tool_use: Option<PendingToolUse>,
1952 canceled: bool,
1953 window: Option<AnyWindowHandle>,
1954 cx: &mut Context<Self>,
1955 ) {
1956 if self.all_tools_finished() {
1957 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
1958 if !canceled {
1959 self.send_to_model(model.clone(), window, cx);
1960 }
1961 self.auto_capture_telemetry(cx);
1962 }
1963 }
1964
1965 cx.emit(ThreadEvent::ToolFinished {
1966 tool_use_id,
1967 pending_tool_use,
1968 });
1969 }
1970
1971 /// Cancels the last pending completion, if there are any pending.
1972 ///
1973 /// Returns whether a completion was canceled.
1974 pub fn cancel_last_completion(
1975 &mut self,
1976 window: Option<AnyWindowHandle>,
1977 cx: &mut Context<Self>,
1978 ) -> bool {
1979 let mut canceled = self.pending_completions.pop().is_some();
1980
1981 for pending_tool_use in self.tool_use.cancel_pending() {
1982 canceled = true;
1983 self.tool_finished(
1984 pending_tool_use.id.clone(),
1985 Some(pending_tool_use),
1986 true,
1987 window,
1988 cx,
1989 );
1990 }
1991
1992 self.finalize_pending_checkpoint(cx);
1993 canceled
1994 }
1995
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2003
    /// Returns the thread-level feedback rating, if one was reported.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2007
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2011
    /// Records `feedback` for `message_id` and reports it to telemetry along
    /// with a serialized copy of the thread and a project snapshot. Repeating
    /// the same rating for the same message is a no-op.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the telemetry event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2069
    /// Reports feedback for the thread as a whole. If the thread contains an
    /// assistant message, the feedback is attributed to the latest one via
    /// `report_message_feedback`; otherwise a thread-level rating is stored
    /// and a telemetry event (without message details) is sent.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach to: record thread-level feedback
            // and emit a telemetry event without message fields.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2115
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Kicks off one `worktree_snapshot` per visible worktree, awaits them
    /// all, then collects the paths of dirty buffers. Errors while reading
    /// the app state are swallowed (`.ok()`), yielding a partial snapshot.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        // Only buffers backed by a file have a reportable path.
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2153
    /// Builds a `WorktreeSnapshot` for `worktree`: its absolute path plus, if
    /// a repository in `git_store` covers it, git state (remote URL, HEAD
    /// SHA, current branch, and a HEAD-to-worktree diff). Any failure along
    /// the way degrades to an empty path and/or `git_state: None`.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the first repository whose root contains this worktree.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Remote (non-local) repositories expose no backend;
                            // report only the branch name in that case.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap Option<Result<Task<_>>>: await the job only when both
            // the repo lookup and the update succeeded.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2231
    /// Renders the entire thread as Markdown: optional `# summary` title,
    /// then one `## Role` section per message including attached context,
    /// text/thinking segments, tool-use invocations (as JSON blocks), and
    /// tool results.
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        // Accumulate as bytes so `writeln!` (io::Write) can be used; decoded
        // lossily at the end.
        let mut markdown = Vec::new();

        if let Some(summary) = self.summary() {
            writeln!(markdown, "# {summary}\n")?;
        };

        for message in self.messages() {
            writeln!(
                markdown,
                "## {role}\n",
                role = match message.role {
                    Role::User => "User",
                    Role::Assistant => "Assistant",
                    Role::System => "System",
                }
            )?;

            if !message.loaded_context.text.is_empty() {
                writeln!(markdown, "{}", message.loaded_context.text)?;
            }

            // Images can't be embedded; note how many were attached.
            if !message.loaded_context.images.is_empty() {
                writeln!(
                    markdown,
                    "\n{} images attached as context.\n",
                    message.loaded_context.images.len()
                )?;
            }

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
                    MessageSegment::Thinking { text, .. } => {
                        writeln!(markdown, "<think>\n{}\n</think>\n", text)?
                    }
                    // Redacted thinking is intentionally omitted from output.
                    MessageSegment::RedactedThinking(_) => {}
                }
            }

            for tool_use in self.tool_uses_for_message(message.id, cx) {
                writeln!(
                    markdown,
                    "**Use Tool: {} ({})**",
                    tool_use.name, tool_use.id
                )?;
                writeln!(markdown, "```json")?;
                writeln!(
                    markdown,
                    "{}",
                    serde_json::to_string_pretty(&tool_use.input)?
                )?;
                writeln!(markdown, "```")?;
            }

            for tool_result in self.tool_results_for_message(message.id) {
                write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
                if tool_result.is_error {
                    write!(markdown, " (Error)")?;
                }

                writeln!(markdown, "**\n")?;
                writeln!(markdown, "{}", tool_result.content)?;
            }
        }

        Ok(String::from_utf8_lossy(&markdown).to_string())
    }
2300
2301 pub fn keep_edits_in_range(
2302 &mut self,
2303 buffer: Entity<language::Buffer>,
2304 buffer_range: Range<language::Anchor>,
2305 cx: &mut Context<Self>,
2306 ) {
2307 self.action_log.update(cx, |action_log, cx| {
2308 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2309 });
2310 }
2311
2312 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2313 self.action_log
2314 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2315 }
2316
2317 pub fn reject_edits_in_ranges(
2318 &mut self,
2319 buffer: Entity<language::Buffer>,
2320 buffer_ranges: Vec<Range<language::Anchor>>,
2321 cx: &mut Context<Self>,
2322 ) -> Task<Result<()>> {
2323 self.action_log.update(cx, |action_log, cx| {
2324 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2325 })
2326 }
2327
    /// Returns the log of edits the agent has made during this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2331
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2335
2336 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2337 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2338 return;
2339 }
2340
2341 let now = Instant::now();
2342 if let Some(last) = self.last_auto_capture_at {
2343 if now.duration_since(last).as_secs() < 10 {
2344 return;
2345 }
2346 }
2347
2348 self.last_auto_capture_at = Some(now);
2349
2350 let thread_id = self.id().clone();
2351 let github_login = self
2352 .project
2353 .read(cx)
2354 .user_store()
2355 .read(cx)
2356 .current_user()
2357 .map(|user| user.github_login.clone());
2358 let client = self.project.read(cx).client().clone();
2359 let serialize_task = self.serialize(cx);
2360
2361 cx.background_executor()
2362 .spawn(async move {
2363 if let Ok(serialized_thread) = serialize_task.await {
2364 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2365 telemetry::event!(
2366 "Agent Thread Auto-Captured",
2367 thread_id = thread_id.to_string(),
2368 thread_data = thread_data,
2369 auto_capture_reason = "tracked_user",
2370 github_login = github_login
2371 );
2372
2373 client.telemetry().flush_events().await;
2374 }
2375 }
2376 })
2377 .detach();
2378 }
2379
    /// Returns the token usage accumulated across the whole thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2383
2384 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2385 let Some(model) = self.configured_model.as_ref() else {
2386 return TotalTokenUsage::default();
2387 };
2388
2389 let max = model.model.max_token_count();
2390
2391 let index = self
2392 .messages
2393 .iter()
2394 .position(|msg| msg.id == message_id)
2395 .unwrap_or(0);
2396
2397 if index == 0 {
2398 return TotalTokenUsage { total: 0, max };
2399 }
2400
2401 let token_usage = &self
2402 .request_token_usage
2403 .get(index - 1)
2404 .cloned()
2405 .unwrap_or_default();
2406
2407 TotalTokenUsage {
2408 total: token_usage.total_tokens() as usize,
2409 max,
2410 }
2411 }
2412
2413 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2414 let model = self.configured_model.as_ref()?;
2415
2416 let max = model.model.max_token_count();
2417
2418 if let Some(exceeded_error) = &self.exceeded_window_error {
2419 if model.model.id() == exceeded_error.model_id {
2420 return Some(TotalTokenUsage {
2421 total: exceeded_error.token_count,
2422 max,
2423 });
2424 }
2425 }
2426
2427 let total = self
2428 .token_usage_at_last_message()
2429 .unwrap_or_default()
2430 .total_tokens() as usize;
2431
2432 Some(TotalTokenUsage { total, max })
2433 }
2434
2435 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2436 self.request_token_usage
2437 .get(self.messages.len().saturating_sub(1))
2438 .or_else(|| self.request_token_usage.last())
2439 .cloned()
2440 }
2441
2442 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2443 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2444 self.request_token_usage
2445 .resize(self.messages.len(), placeholder);
2446
2447 if let Some(last) = self.request_token_usage.last_mut() {
2448 *last = token_usage;
2449 }
2450 }
2451
2452 pub fn deny_tool_use(
2453 &mut self,
2454 tool_use_id: LanguageModelToolUseId,
2455 tool_name: Arc<str>,
2456 window: Option<AnyWindowHandle>,
2457 cx: &mut Context<Self>,
2458 ) {
2459 let err = Err(anyhow::anyhow!(
2460 "Permission to run tool action denied by user"
2461 ));
2462
2463 self.tool_use.insert_tool_output(
2464 tool_use_id.clone(),
2465 tool_name,
2466 err,
2467 self.configured_model.as_ref(),
2468 );
2469 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2470 }
2471}
2472
/// User-facing errors surfaced while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before further completions can run.
    #[error("Payment required")]
    PaymentRequired,
    /// The account's maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model-request limit has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a header and detail message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2487
/// Events emitted by a `Thread` as it streams completions, runs tools, and
/// mutates its message list.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// Updated request-usage information was received.
    UsageUpdated(RequestUsage),
    /// A completion chunk was streamed in.
    StreamedCompletion,
    /// A chunk of plain text was received from the model.
    ReceivedTextChunk,
    /// Streamed assistant text for the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model produced tool input that failed to parse as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses should now be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A git checkpoint associated with the thread changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Editing should be canceled.
    CancelEditing,
}
2525
// Lets `Thread` broadcast `ThreadEvent`s to GPUI subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2527
/// A completion request that is currently in flight for a thread.
struct PendingCompletion {
    // Monotonic id used to distinguish this completion from others.
    id: usize,
    queue_state: QueueState,
    // Held only to keep the spawned completion task owned for the lifetime
    // of this struct (the underscore marks it as otherwise unused).
    _task: Task<()>,
}
2533
2534#[cfg(test)]
2535mod tests {
2536 use super::*;
2537 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2538 use assistant_settings::AssistantSettings;
2539 use assistant_tool::ToolRegistry;
2540 use context_server::ContextServerSettings;
2541 use editor::EditorSettings;
2542 use gpui::TestAppContext;
2543 use language_model::fake_provider::FakeLanguageModel;
2544 use project::{FakeFs, Project};
2545 use prompt_store::PromptBuilder;
2546 use serde_json::json;
2547 use settings::{Settings, SettingsStore};
2548 use std::sync::Arc;
2549 use theme::ThemeSettings;
2550 use util::path;
2551 use workspace::Workspace;
2552
    // Verifies that a file attached as context is rendered into the message's
    // `<context>` block and included verbatim in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // messages[0] is the system prompt; messages[1] is our user message.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2628
    // Verifies that each new user message only includes context that hasn't
    // already been attached to an earlier message, and that
    // `new_context_for_thread` respects an "up to message" boundary.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages (plus the system prompt)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a boundary at message 2, everything attached after message 1
        // (files 2, 3, and the new file 4) counts as "new".
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2784
    // Verifies that messages inserted with an empty `ContextLoadResult` carry
    // no context text and appear unmodified in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2862
    // Verifies that once a context buffer is modified after being read, a
    // "These files changed since last read" notification is appended to the
    // completion request as the final user message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2949
    /// Registers all settings globals and subsystems the thread tests need.
    // NOTE(review): registration order may matter for init dependencies — kept
    // exactly as written.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
2967
2968 // Helper to create a test project with test files
2969 async fn create_test_project(
2970 cx: &mut TestAppContext,
2971 files: serde_json::Value,
2972 ) -> Entity<Project> {
2973 let fs = FakeFs::new(cx.executor());
2974 fs.insert_tree(path!("/test"), files).await;
2975 Project::test(fs, [path!("/test").as_ref()], cx).await
2976 }
2977
    /// Creates the standard fixture for thread tests: a workspace window, a
    /// loaded `ThreadStore`, a fresh thread, an empty `ContextStore`, and a
    /// fake language model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }
3012
3013 async fn add_file_to_context(
3014 project: &Entity<Project>,
3015 context_store: &Entity<ContextStore>,
3016 path: &str,
3017 cx: &mut TestAppContext,
3018 ) -> Result<Entity<language::Buffer>> {
3019 let buffer_path = project
3020 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3021 .unwrap();
3022
3023 let buffer = project
3024 .update(cx, |project, cx| {
3025 project.open_buffer(buffer_path.clone(), cx)
3026 })
3027 .await
3028 .unwrap();
3029
3030 context_store.update(cx, |context_store, cx| {
3031 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3032 });
3033
3034 Ok(buffer)
3035 }
3036}