1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, TryFutureExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
/// Unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
55
impl ThreadId {
    /// Creates a new random (UUIDv4-based) thread ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a new random (UUIDv4-based) prompt ID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
91
/// Identifier of a message within a thread, assigned from a
/// monotonically increasing counter.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    // Returns the current ID and advances the counter to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range of the crease within the message text.
    pub range: Range<usize>,
    /// Icon and label used to render the crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context that was attached when this message was sent.
    pub loaded_context: LoadedContext,
    /// Creases for re-rendering context mentions in an editor.
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
/// A single piece of content within a [`Message`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain text content.
    Text(String),
    /// A "thinking"/reasoning block, optionally carrying the provider's
    /// signature for the thinking content.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-encrypted thinking content; never displayed.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has meaningful content to display.
    ///
    /// Fix: the emptiness checks were inverted — an empty text/thinking
    /// segment was reported as displayable and a non-empty one as hidden,
    /// which inverted the result of [`Message::should_display_content`].
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking is opaque ciphertext; nothing to display.
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Snapshot of the project's worktrees taken when a thread starts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when the worktree has no associated git repository state.
    pub git_state: Option<GitState>,
}

/// Captured git repository metadata for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
216
/// Associates a git checkpoint with the message it was captured for.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up/down) on a thread or a message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
228
229pub enum LastRestoreCheckpoint {
230 Pending {
231 message_id: MessageId,
232 },
233 Error {
234 message_id: MessageId,
235 error: String,
236 },
237}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
248#[derive(Clone, Debug, Default, Serialize, Deserialize)]
249pub enum DetailedSummaryState {
250 #[default]
251 NotGenerated,
252 Generating {
253 message_id: MessageId,
254 },
255 Generated {
256 text: SharedString,
257 message_id: MessageId,
258 },
259}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Aggregate token usage for a thread, measured against the model's
/// context-window size.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum context size; `0` when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be tuned via an env var
        // to make it easier to exercise the warning UI.
        //
        // Fix: `unwrap_or(String)` allocated the default eagerly on every
        // call even when the variable was set (clippy `or_fun_call`), and
        // the bare `unwrap()` on parse gave no diagnostic on failure.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .unwrap_or_else(|_| "0.8".to_string())
            .parse()
            .expect("ZED_THREAD_WARNING_THRESHOLD must parse as an f32");
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close a thread's token usage is to the model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
315
/// Position of a pending completion in the upstream request queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is queued at the given position.
    Queued { position: usize },
    /// The request has started streaming.
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// `None` until a summary has been generated for the thread.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    /// Messages ordered by (monotonically increasing) ID.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured at the last user message; finalized (kept or
    /// discarded) once generation and tool use complete.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Timestamp of the last streamed chunk; used to detect stale generation.
    last_received_chunk_at: Option<Instant>,
    /// Optional hook invoked with each request and its response events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Turns left before `send_to_model` becomes a no-op.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
363
/// Recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
371
372impl Thread {
    /// Creates a new, empty thread with a fresh ID, using the global default
    /// model and the user's preferred completion mode.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project snapshot asynchronously so thread
                // creation itself stays fast.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
426
    /// Reconstructs a thread from its serialized form.
    ///
    /// Restores messages, tool-use state, the detailed-summary channel, and
    /// the previously selected model (falling back to the global default).
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    // Creases are restored without live context handles.
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
546
    /// Installs a callback invoked with every completion request and the
    /// response events received so far; used to capture request/response
    /// pairs (e.g. for evals and debugging).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
554
    /// The thread's unique identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread contains no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, if one exists.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared project context used when building the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, falling back to (and caching) the
    /// global default model when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The model currently configured for this thread, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the model used by this thread.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
598
    /// Fallback summary shown before one has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The summary, or [`Self::DEFAULT_SUMMARY`] when none exists yet.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }

    /// Replaces the summary, emitting [`ThreadEvent::SummaryChanged`] only
    /// when it actually changes. No-op until an initial summary exists.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let Some(current_summary) = &self.summary else {
            // Don't allow setting summary until generated
            return;
        };

        let mut new_summary = new_summary.into();

        // An empty summary falls back to the default placeholder.
        if new_summary.is_empty() {
            new_summary = Self::DEFAULT_SUMMARY;
        }

        if current_summary != &new_summary {
            self.summary = Some(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    /// The completion mode used for this thread's requests.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
630
    /// Looks up a message by ID.
    ///
    /// Uses binary search, relying on `messages` being sorted by ID (IDs are
    /// assigned from a monotonically increasing counter on insertion).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
647
    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a new chunk before generation counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        // NOTE(review): returns `Some` whenever a chunk was ever received,
        // regardless of `is_generating()` — callers appear expected to gate
        // on `is_generating()` themselves; confirm against call sites.
        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    // Records that a streaming chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
666
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
692
    /// Restores the project to the given checkpoint's git state, then
    /// truncates the thread back to the checkpoint's message on success.
    ///
    /// Progress and failure are surfaced through `last_restore_checkpoint`
    /// and [`ThreadEvent::CheckpointChanged`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as pending so the UI can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so the UI can show why the restore failed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
728
    /// Finalizes the checkpoint captured at the last user message: if the
    /// repository state changed during generation the checkpoint is kept so
    /// the user can restore it; if nothing changed it is discarded.
    ///
    /// Does nothing while generation is still in progress.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // Compare pre- and post-generation repository state; a
                // comparison failure conservatively counts as "changed".
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If we can't take a final checkpoint, keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
767
    // Records a checkpoint against its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The state of the most recent checkpoint-restore attempt, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
778
779 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
780 let Some(message_ix) = self
781 .messages
782 .iter()
783 .rposition(|message| message.id == message_id)
784 else {
785 return;
786 };
787 for deleted_message in self.messages.drain(message_ix..) {
788 self.checkpoints_by_message.remove(&deleted_message.id);
789 }
790 cx.notify();
791 }
792
    /// Iterates over the context items attached to the given message
    /// (empty when the message doesn't exist or has no context).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
800
801 pub fn is_turn_end(&self, ix: usize) -> bool {
802 if self.messages.is_empty() {
803 return false;
804 }
805
806 if !self.is_generating() && ix == self.messages.len() - 1 {
807 return true;
808 }
809
810 let Some(message) = self.messages.get(ix) else {
811 return false;
812 };
813
814 if message.role != Role::Assistant {
815 return false;
816 }
817
818 self.messages
819 .get(ix + 1)
820 .and_then(|message| {
821 self.message(message.id)
822 .map(|next_message| next_message.role == Role::User)
823 })
824 .unwrap_or(false)
825 }
826
    /// Usage reported by the most recent completion request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the server reported that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
844
    /// Tool uses issued by the message with the given ID.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results attached to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result recorded for a specific tool use, if any.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The output content of a tool use's result, if any.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool use's result, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
867
    /// Return tools that are both enabled and supported by the model
    pub fn available_tools(
        &self,
        cx: &App,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<LanguageModelRequestTool> {
        if model.supports_tools() {
            self.tools()
                .read(cx)
                .enabled_tools(cx)
                .into_iter()
                .filter_map(|tool| {
                    // Skip tools that cannot be supported
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name: tool.name(),
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect()
        } else {
            // Models without tool support get no tools at all.
            Vec::default()
        }
    }
893
    /// Appends a user message, marking any context-referenced buffers as
    /// read in the action log and stashing `git_checkpoint` as the pending
    /// checkpoint for this message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        // The checkpoint lets the user roll the project back to the state it
        // was in when this message was sent.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
929
    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }

    /// Appends a message with the next available ID, bumps `updated_at`,
    /// and emits [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
964
    /// Replaces the role and segments (and optionally the context) of an
    /// existing message. Returns `false` when no message has the given ID.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    /// Removes a message by ID. Returns `false` when no message has the ID.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }
995
996 /// Returns the representation of this [`Thread`] in a textual form.
997 ///
998 /// This is the representation we use when attaching a thread as context to another thread.
999 pub fn text(&self) -> String {
1000 let mut text = String::new();
1001
1002 for message in &self.messages {
1003 text.push_str(match message.role {
1004 language_model::Role::User => "User:",
1005 language_model::Role::Assistant => "Agent:",
1006 language_model::Role::System => "System:",
1007 });
1008 text.push('\n');
1009
1010 for segment in &message.segments {
1011 match segment {
1012 MessageSegment::Text(content) => text.push_str(content),
1013 MessageSegment::Thinking { text: content, .. } => {
1014 text.push_str(&format!("<think>{}</think>", content))
1015 }
1016 MessageSegment::RedactedThinking(_) => {}
1017 }
1018 }
1019 text.push('\n');
1020 }
1021
1022 text
1023 }
1024
1025 /// Serializes this thread into a format for storage or telemetry.
1026 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
1027 let initial_project_snapshot = self.initial_project_snapshot.clone();
1028 cx.spawn(async move |this, cx| {
1029 let initial_project_snapshot = initial_project_snapshot.await;
1030 this.read_with(cx, |this, cx| SerializedThread {
1031 version: SerializedThread::VERSION.to_string(),
1032 summary: this.summary_or_default(),
1033 updated_at: this.updated_at(),
1034 messages: this
1035 .messages()
1036 .map(|message| SerializedMessage {
1037 id: message.id,
1038 role: message.role,
1039 segments: message
1040 .segments
1041 .iter()
1042 .map(|segment| match segment {
1043 MessageSegment::Text(text) => {
1044 SerializedMessageSegment::Text { text: text.clone() }
1045 }
1046 MessageSegment::Thinking { text, signature } => {
1047 SerializedMessageSegment::Thinking {
1048 text: text.clone(),
1049 signature: signature.clone(),
1050 }
1051 }
1052 MessageSegment::RedactedThinking(data) => {
1053 SerializedMessageSegment::RedactedThinking {
1054 data: data.clone(),
1055 }
1056 }
1057 })
1058 .collect(),
1059 tool_uses: this
1060 .tool_uses_for_message(message.id, cx)
1061 .into_iter()
1062 .map(|tool_use| SerializedToolUse {
1063 id: tool_use.id,
1064 name: tool_use.name,
1065 input: tool_use.input,
1066 })
1067 .collect(),
1068 tool_results: this
1069 .tool_results_for_message(message.id)
1070 .into_iter()
1071 .map(|tool_result| SerializedToolResult {
1072 tool_use_id: tool_result.tool_use_id.clone(),
1073 is_error: tool_result.is_error,
1074 content: tool_result.content.clone(),
1075 output: tool_result.output.clone(),
1076 })
1077 .collect(),
1078 context: message.loaded_context.text.clone(),
1079 creases: message
1080 .creases
1081 .iter()
1082 .map(|crease| SerializedCrease {
1083 start: crease.range.start,
1084 end: crease.range.end,
1085 icon_path: crease.metadata.icon_path.clone(),
1086 label: crease.metadata.label.clone(),
1087 })
1088 .collect(),
1089 })
1090 .collect(),
1091 initial_project_snapshot,
1092 cumulative_token_usage: this.cumulative_token_usage,
1093 request_token_usage: this.request_token_usage.clone(),
1094 detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
1095 exceeded_window_error: this.exceeded_window_error.clone(),
1096 model: this
1097 .configured_model
1098 .as_ref()
1099 .map(|model| SerializedLanguageModel {
1100 provider: model.provider.id().0.to_string(),
1101 model: model.model.id().0.to_string(),
1102 }),
1103 completion_mode: Some(this.completion_mode),
1104 })
1105 })
1106 }
1107
    /// Number of turns the thread may still take before stopping.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    /// Sets the number of remaining turns.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }

    /// Sends the current conversation to `model`, consuming one remaining
    /// turn. Does nothing when no turns remain.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), cx);

        self.stream_completion(request, model, window, cx);
    }
1132
1133 pub fn used_tools_since_last_user_message(&self) -> bool {
1134 for message in self.messages.iter().rev() {
1135 if self.tool_use.message_has_tool_results(message.id) {
1136 return true;
1137 } else if message.role == Role::User {
1138 return false;
1139 }
1140 }
1141
1142 false
1143 }
1144
    /// Builds the full [`LanguageModelRequest`] for this thread: the system
    /// prompt (when the project context is ready), every message with its
    /// loaded context, segments, and tool uses/results, a stale-files note,
    /// and the set of tools available to `model`.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt needs the project context; if it isn't ready yet,
        // surface an error rather than silently sending a degraded request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, …) precedes the message text.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Empty text segments are dropped.
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results go into a separate message that follows the
            // assistant message which requested them.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Max mode is only requested when the model actually supports it.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1262
1263 fn to_summarize_request(
1264 &self,
1265 model: &Arc<dyn LanguageModel>,
1266 added_user_message: String,
1267 cx: &App,
1268 ) -> LanguageModelRequest {
1269 let mut request = LanguageModelRequest {
1270 thread_id: None,
1271 prompt_id: None,
1272 mode: None,
1273 messages: vec![],
1274 tools: Vec::new(),
1275 stop: Vec::new(),
1276 temperature: AssistantSettings::temperature_for_model(model, cx),
1277 };
1278
1279 for message in &self.messages {
1280 let mut request_message = LanguageModelRequestMessage {
1281 role: message.role,
1282 content: Vec::new(),
1283 cache: false,
1284 };
1285
1286 for segment in &message.segments {
1287 match segment {
1288 MessageSegment::Text(text) => request_message
1289 .content
1290 .push(MessageContent::Text(text.clone())),
1291 MessageSegment::Thinking { .. } => {}
1292 MessageSegment::RedactedThinking(_) => {}
1293 }
1294 }
1295
1296 if request_message.content.is_empty() {
1297 continue;
1298 }
1299
1300 request.messages.push(request_message);
1301 }
1302
1303 request.messages.push(LanguageModelRequestMessage {
1304 role: Role::User,
1305 content: vec![MessageContent::Text(added_user_message)],
1306 cache: false,
1307 });
1308
1309 request
1310 }
1311
1312 fn attached_tracked_files_state(
1313 &self,
1314 messages: &mut Vec<LanguageModelRequestMessage>,
1315 cx: &App,
1316 ) {
1317 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1318
1319 let mut stale_message = String::new();
1320
1321 let action_log = self.action_log.read(cx);
1322
1323 for stale_file in action_log.stale_buffers(cx) {
1324 let Some(file) = stale_file.read(cx).file() else {
1325 continue;
1326 };
1327
1328 if stale_message.is_empty() {
1329 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1330 }
1331
1332 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1333 }
1334
1335 let mut content = Vec::with_capacity(2);
1336
1337 if !stale_message.is_empty() {
1338 content.push(stale_message.into());
1339 }
1340
1341 if !content.is_empty() {
1342 let context_message = LanguageModelRequestMessage {
1343 role: Role::User,
1344 content,
1345 cache: false,
1346 };
1347
1348 messages.push(context_message);
1349 }
1350 }
1351
    /// Streams a completion for `request` from `model`, applying events to the
    /// thread as they arrive: text/thinking chunks, tool uses, usage and queue
    /// status updates, then stop/error handling. The work runs in a background
    /// task that is tracked in `pending_completions` until it finishes.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // When a request callback is installed, capture the request and every
        // response event so they can be replayed to it once the stream ends.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // Id of the assistant message this response is streaming into,
                // created lazily on the first relevant event.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        // Recoverable errors (bad tool-call JSON) are handled
                        // inline; any other error aborts the stream.
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // only the delta since the last update is added.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to the current assistant
                                // message, creating one if none exists yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Incomplete input means the tool call is still
                                // streaming; remember it so the UI can follow along.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so other tasks get a chance to run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to specific UI events;
                            // anything else becomes a generic error message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report only the tokens consumed by this request.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1698
    /// Generates a short (3-7 word) title for the thread using the configured
    /// summary model. Does nothing when no summary model is configured or its
    /// provider is not authenticated.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);

        self.pending_summary = cx.spawn(async move |this, cx| {
            // Inner async block so errors can be funneled through `log_err`.
            async move {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            // Usage updates are still tracked while summarizing.
                            this.update(cx, |thread, _cx| {
                                thread.last_usage = Some(RequestUsage {
                                    limit,
                                    amount: amount as i32,
                                });
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    // Only the first line of the response is used as the title.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1761
    /// Kicks off detailed-summary generation unless the existing summary (or
    /// the one currently being generated) already covers the last message.
    /// On success the thread is saved so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // A failed request resets the state back to `NotGenerated`.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1851
1852 pub async fn wait_for_detailed_summary_or_text(
1853 this: &Entity<Self>,
1854 cx: &mut AsyncApp,
1855 ) -> Option<SharedString> {
1856 let mut detailed_summary_rx = this
1857 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1858 .ok()?;
1859 loop {
1860 match detailed_summary_rx.recv().await? {
1861 DetailedSummaryState::Generating { .. } => {}
1862 DetailedSummaryState::NotGenerated => {
1863 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1864 }
1865 DetailedSummaryState::Generated { text, .. } => return Some(text),
1866 }
1867 }
1868 }
1869
1870 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1871 self.detailed_summary_rx
1872 .borrow()
1873 .text()
1874 .unwrap_or_else(|| self.text().into())
1875 }
1876
1877 pub fn is_generating_detailed_summary(&self) -> bool {
1878 matches!(
1879 &*self.detailed_summary_rx.borrow(),
1880 DetailedSummaryState::Generating { .. }
1881 )
1882 }
1883
1884 pub fn use_pending_tools(
1885 &mut self,
1886 window: Option<AnyWindowHandle>,
1887 cx: &mut Context<Self>,
1888 model: Arc<dyn LanguageModel>,
1889 ) -> Vec<PendingToolUse> {
1890 self.auto_capture_telemetry(cx);
1891 let request = self.to_completion_request(model, cx);
1892 let messages = Arc::new(request.messages);
1893 let pending_tool_uses = self
1894 .tool_use
1895 .pending_tool_uses()
1896 .into_iter()
1897 .filter(|tool_use| tool_use.status.is_idle())
1898 .cloned()
1899 .collect::<Vec<_>>();
1900
1901 for tool_use in pending_tool_uses.iter() {
1902 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1903 if tool.needs_confirmation(&tool_use.input, cx)
1904 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1905 {
1906 self.tool_use.confirm_tool_use(
1907 tool_use.id.clone(),
1908 tool_use.ui_text.clone(),
1909 tool_use.input.clone(),
1910 messages.clone(),
1911 tool,
1912 );
1913 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1914 } else {
1915 self.run_tool(
1916 tool_use.id.clone(),
1917 tool_use.ui_text.clone(),
1918 tool_use.input.clone(),
1919 &messages,
1920 tool,
1921 window,
1922 cx,
1923 );
1924 }
1925 } else {
1926 self.handle_hallucinated_tool_use(
1927 tool_use.id.clone(),
1928 tool_use.name.clone(),
1929 window,
1930 cx,
1931 );
1932 }
1933 }
1934
1935 pending_tool_uses
1936 }
1937
1938 pub fn handle_hallucinated_tool_use(
1939 &mut self,
1940 tool_use_id: LanguageModelToolUseId,
1941 hallucinated_tool_name: Arc<str>,
1942 window: Option<AnyWindowHandle>,
1943 cx: &mut Context<Thread>,
1944 ) {
1945 let available_tools = self.tools.read(cx).enabled_tools(cx);
1946
1947 let tool_list = available_tools
1948 .iter()
1949 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
1950 .collect::<Vec<_>>()
1951 .join("\n");
1952
1953 let error_message = format!(
1954 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
1955 hallucinated_tool_name, tool_list
1956 );
1957
1958 let pending_tool_use = self.tool_use.insert_tool_output(
1959 tool_use_id.clone(),
1960 hallucinated_tool_name,
1961 Err(anyhow!("Missing tool call: {error_message}")),
1962 self.configured_model.as_ref(),
1963 );
1964
1965 cx.emit(ThreadEvent::MissingToolUse {
1966 tool_use_id: tool_use_id.clone(),
1967 ui_text: error_message.into(),
1968 });
1969
1970 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1971 }
1972
1973 pub fn receive_invalid_tool_json(
1974 &mut self,
1975 tool_use_id: LanguageModelToolUseId,
1976 tool_name: Arc<str>,
1977 invalid_json: Arc<str>,
1978 error: String,
1979 window: Option<AnyWindowHandle>,
1980 cx: &mut Context<Thread>,
1981 ) {
1982 log::error!("The model returned invalid input JSON: {invalid_json}");
1983
1984 let pending_tool_use = self.tool_use.insert_tool_output(
1985 tool_use_id.clone(),
1986 tool_name,
1987 Err(anyhow!("Error parsing input JSON: {error}")),
1988 self.configured_model.as_ref(),
1989 );
1990 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1991 pending_tool_use.ui_text.clone()
1992 } else {
1993 log::error!(
1994 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1995 );
1996 format!("Unknown tool {}", tool_use_id).into()
1997 };
1998
1999 cx.emit(ThreadEvent::InvalidToolInput {
2000 tool_use_id: tool_use_id.clone(),
2001 ui_text,
2002 invalid_input_json: invalid_json,
2003 });
2004
2005 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2006 }
2007
2008 pub fn run_tool(
2009 &mut self,
2010 tool_use_id: LanguageModelToolUseId,
2011 ui_text: impl Into<SharedString>,
2012 input: serde_json::Value,
2013 messages: &[LanguageModelRequestMessage],
2014 tool: Arc<dyn Tool>,
2015 window: Option<AnyWindowHandle>,
2016 cx: &mut Context<Thread>,
2017 ) {
2018 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
2019 self.tool_use
2020 .run_pending_tool(tool_use_id, ui_text.into(), task);
2021 }
2022
    /// Runs `tool` with `input` (unless the tool has been disabled) and
    /// returns a task that records the tool's output when it completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Disabled tools resolve immediately to an error result.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        // Completing this tool may kick off the next model turn.
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2071
2072 fn tool_finished(
2073 &mut self,
2074 tool_use_id: LanguageModelToolUseId,
2075 pending_tool_use: Option<PendingToolUse>,
2076 canceled: bool,
2077 window: Option<AnyWindowHandle>,
2078 cx: &mut Context<Self>,
2079 ) {
2080 if self.all_tools_finished() {
2081 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2082 if !canceled {
2083 self.send_to_model(model.clone(), window, cx);
2084 }
2085 self.auto_capture_telemetry(cx);
2086 }
2087 }
2088
2089 cx.emit(ThreadEvent::ToolFinished {
2090 tool_use_id,
2091 pending_tool_use,
2092 });
2093 }
2094
2095 /// Cancels the last pending completion, if there are any pending.
2096 ///
2097 /// Returns whether a completion was canceled.
2098 pub fn cancel_last_completion(
2099 &mut self,
2100 window: Option<AnyWindowHandle>,
2101 cx: &mut Context<Self>,
2102 ) -> bool {
2103 let mut canceled = self.pending_completions.pop().is_some();
2104
2105 for pending_tool_use in self.tool_use.cancel_pending() {
2106 canceled = true;
2107 self.tool_finished(
2108 pending_tool_use.id.clone(),
2109 Some(pending_tool_use),
2110 true,
2111 window,
2112 cx,
2113 );
2114 }
2115
2116 self.finalize_pending_checkpoint(cx);
2117
2118 if canceled {
2119 cx.emit(ThreadEvent::CompletionCanceled);
2120 }
2121
2122 canceled
2123 }
2124
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2132
    /// Returns the thread-level feedback recorded via [`Self::report_feedback`],
    /// if any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2136
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2140
    /// Records `feedback` for `message_id` and reports it to telemetry along
    /// with a full serialized thread and project snapshot. Returns early when
    /// the feedback is unchanged.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything we need before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure downgrades to Null rather than aborting
            // the telemetry report.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2198
    /// Records thread-level `feedback`. When the thread contains an assistant
    /// message, this delegates to [`Self::report_message_feedback`] for the
    /// most recent one; otherwise the whole-thread rating is reported directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet: report a thread-level rating only.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2244
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are gathered concurrently; dirty buffer paths are
    /// collected on the app thread before the snapshot is assembled.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    // Only dirty buffers with a backing file are recorded.
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2282
2283 fn worktree_snapshot(
2284 worktree: Entity<project::Worktree>,
2285 git_store: Entity<GitStore>,
2286 cx: &App,
2287 ) -> Task<WorktreeSnapshot> {
2288 cx.spawn(async move |cx| {
2289 // Get worktree path and snapshot
2290 let worktree_info = cx.update(|app_cx| {
2291 let worktree = worktree.read(app_cx);
2292 let path = worktree.abs_path().to_string_lossy().to_string();
2293 let snapshot = worktree.snapshot();
2294 (path, snapshot)
2295 });
2296
2297 let Ok((worktree_path, _snapshot)) = worktree_info else {
2298 return WorktreeSnapshot {
2299 worktree_path: String::new(),
2300 git_state: None,
2301 };
2302 };
2303
2304 let git_state = git_store
2305 .update(cx, |git_store, cx| {
2306 git_store
2307 .repositories()
2308 .values()
2309 .find(|repo| {
2310 repo.read(cx)
2311 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2312 .is_some()
2313 })
2314 .cloned()
2315 })
2316 .ok()
2317 .flatten()
2318 .map(|repo| {
2319 repo.update(cx, |repo, _| {
2320 let current_branch =
2321 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2322 repo.send_job(None, |state, _| async move {
2323 let RepositoryState::Local { backend, .. } = state else {
2324 return GitState {
2325 remote_url: None,
2326 head_sha: None,
2327 current_branch,
2328 diff: None,
2329 };
2330 };
2331
2332 let remote_url = backend.remote_url("origin");
2333 let head_sha = backend.head_sha().await;
2334 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2335
2336 GitState {
2337 remote_url,
2338 head_sha,
2339 current_branch,
2340 diff,
2341 }
2342 })
2343 })
2344 });
2345
2346 let git_state = match git_state {
2347 Some(git_state) => match git_state.ok() {
2348 Some(git_state) => git_state.await.ok(),
2349 None => None,
2350 },
2351 None => None,
2352 };
2353
2354 WorktreeSnapshot {
2355 worktree_path,
2356 git_state,
2357 }
2358 })
2359 }
2360
2361 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2362 let mut markdown = Vec::new();
2363
2364 if let Some(summary) = self.summary() {
2365 writeln!(markdown, "# {summary}\n")?;
2366 };
2367
2368 for message in self.messages() {
2369 writeln!(
2370 markdown,
2371 "## {role}\n",
2372 role = match message.role {
2373 Role::User => "User",
2374 Role::Assistant => "Agent",
2375 Role::System => "System",
2376 }
2377 )?;
2378
2379 if !message.loaded_context.text.is_empty() {
2380 writeln!(markdown, "{}", message.loaded_context.text)?;
2381 }
2382
2383 if !message.loaded_context.images.is_empty() {
2384 writeln!(
2385 markdown,
2386 "\n{} images attached as context.\n",
2387 message.loaded_context.images.len()
2388 )?;
2389 }
2390
2391 for segment in &message.segments {
2392 match segment {
2393 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2394 MessageSegment::Thinking { text, .. } => {
2395 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2396 }
2397 MessageSegment::RedactedThinking(_) => {}
2398 }
2399 }
2400
2401 for tool_use in self.tool_uses_for_message(message.id, cx) {
2402 writeln!(
2403 markdown,
2404 "**Use Tool: {} ({})**",
2405 tool_use.name, tool_use.id
2406 )?;
2407 writeln!(markdown, "```json")?;
2408 writeln!(
2409 markdown,
2410 "{}",
2411 serde_json::to_string_pretty(&tool_use.input)?
2412 )?;
2413 writeln!(markdown, "```")?;
2414 }
2415
2416 for tool_result in self.tool_results_for_message(message.id) {
2417 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2418 if tool_result.is_error {
2419 write!(markdown, " (Error)")?;
2420 }
2421
2422 writeln!(markdown, "**\n")?;
2423 writeln!(markdown, "{}", tool_result.content)?;
2424 }
2425 }
2426
2427 Ok(String::from_utf8_lossy(&markdown).to_string())
2428 }
2429
2430 pub fn keep_edits_in_range(
2431 &mut self,
2432 buffer: Entity<language::Buffer>,
2433 buffer_range: Range<language::Anchor>,
2434 cx: &mut Context<Self>,
2435 ) {
2436 self.action_log.update(cx, |action_log, cx| {
2437 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2438 });
2439 }
2440
2441 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2442 self.action_log
2443 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2444 }
2445
2446 pub fn reject_edits_in_ranges(
2447 &mut self,
2448 buffer: Entity<language::Buffer>,
2449 buffer_ranges: Vec<Range<language::Anchor>>,
2450 cx: &mut Context<Self>,
2451 ) -> Task<Result<()>> {
2452 self.action_log.update(cx, |action_log, cx| {
2453 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2454 })
2455 }
2456
    /// The log of buffer edits the agent has made in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2460
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2464
2465 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2466 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2467 return;
2468 }
2469
2470 let now = Instant::now();
2471 if let Some(last) = self.last_auto_capture_at {
2472 if now.duration_since(last).as_secs() < 10 {
2473 return;
2474 }
2475 }
2476
2477 self.last_auto_capture_at = Some(now);
2478
2479 let thread_id = self.id().clone();
2480 let github_login = self
2481 .project
2482 .read(cx)
2483 .user_store()
2484 .read(cx)
2485 .current_user()
2486 .map(|user| user.github_login.clone());
2487 let client = self.project.read(cx).client().clone();
2488 let serialize_task = self.serialize(cx);
2489
2490 cx.background_executor()
2491 .spawn(async move {
2492 if let Ok(serialized_thread) = serialize_task.await {
2493 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2494 telemetry::event!(
2495 "Agent Thread Auto-Captured",
2496 thread_id = thread_id.to_string(),
2497 thread_data = thread_data,
2498 auto_capture_reason = "tracked_user",
2499 github_login = github_login
2500 );
2501
2502 client.telemetry().flush_events().await;
2503 }
2504 }
2505 })
2506 .detach();
2507 }
2508
    /// Total token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2512
2513 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2514 let Some(model) = self.configured_model.as_ref() else {
2515 return TotalTokenUsage::default();
2516 };
2517
2518 let max = model.model.max_token_count();
2519
2520 let index = self
2521 .messages
2522 .iter()
2523 .position(|msg| msg.id == message_id)
2524 .unwrap_or(0);
2525
2526 if index == 0 {
2527 return TotalTokenUsage { total: 0, max };
2528 }
2529
2530 let token_usage = &self
2531 .request_token_usage
2532 .get(index - 1)
2533 .cloned()
2534 .unwrap_or_default();
2535
2536 TotalTokenUsage {
2537 total: token_usage.total_tokens() as usize,
2538 max,
2539 }
2540 }
2541
2542 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2543 let model = self.configured_model.as_ref()?;
2544
2545 let max = model.model.max_token_count();
2546
2547 if let Some(exceeded_error) = &self.exceeded_window_error {
2548 if model.model.id() == exceeded_error.model_id {
2549 return Some(TotalTokenUsage {
2550 total: exceeded_error.token_count,
2551 max,
2552 });
2553 }
2554 }
2555
2556 let total = self
2557 .token_usage_at_last_message()
2558 .unwrap_or_default()
2559 .total_tokens() as usize;
2560
2561 Some(TotalTokenUsage { total, max })
2562 }
2563
    /// Token usage recorded for the most recent message, if any.
    ///
    /// `request_token_usage` is indexed in parallel with `messages`; when it is
    /// shorter (usage not yet recorded for the newest message), fall back to the
    /// most recent recorded entry.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2570
2571 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2572 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2573 self.request_token_usage
2574 .resize(self.messages.len(), placeholder);
2575
2576 if let Some(last) = self.request_token_usage.last_mut() {
2577 *last = token_usage;
2578 }
2579 }
2580
2581 pub fn deny_tool_use(
2582 &mut self,
2583 tool_use_id: LanguageModelToolUseId,
2584 tool_name: Arc<str>,
2585 window: Option<AnyWindowHandle>,
2586 cx: &mut Context<Self>,
2587 ) {
2588 let err = Err(anyhow::anyhow!(
2589 "Permission to run tool action denied by user"
2590 ));
2591
2592 self.tool_use.insert_tool_output(
2593 tool_use_id.clone(),
2594 tool_name,
2595 err,
2596 self.configured_model.as_ref(),
2597 );
2598 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2599 }
2600}
2601
/// Errors surfaced to the user while running a thread (see `ThreadEvent::ShowError`).
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before further requests are accepted.
    #[error("Payment required")]
    PaymentRequired,
    /// The account's monthly spending cap has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request limit for the user's plan has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a display header and detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2616
/// Events emitted by a thread while a completion streams, tools run, and
/// messages change; broadcast to observers via gpui's `EventEmitter`.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Surface an error to the user.
    ShowError(ThreadError),
    /// A completion chunk was received.
    StreamedCompletion,
    /// A text chunk was received from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant response text for the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// A tool use was expected but not provided by the model.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished (or failed) with the given stop reason.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message's content changed.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's git checkpoint changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it may run.
    ToolConfirmationNeeded,
    /// Editing was canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2659
// Lets `Thread` broadcast `ThreadEvent`s to gpui observers.
impl EventEmitter<ThreadEvent> for Thread {}
2661
/// An in-flight completion request and the task driving it.
struct PendingCompletion {
    // Identifier for this completion. NOTE(review): presumably allocated via
    // `post_inc` at the construction site — confirm there.
    id: usize,
    // Position/state of this completion in the request queue (`QueueState` is
    // defined elsewhere in this file).
    queue_state: QueueState,
    // Owns the background task; gpui tasks are canceled when dropped, so
    // holding this keeps the stream alive.
    _task: Task<()>,
}
2667
2668#[cfg(test)]
2669mod tests {
2670 use super::*;
2671 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2672 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2673 use assistant_tool::ToolRegistry;
2674 use editor::EditorSettings;
2675 use gpui::TestAppContext;
2676 use language_model::fake_provider::FakeLanguageModel;
2677 use project::{FakeFs, Project};
2678 use prompt_store::PromptBuilder;
2679 use serde_json::json;
2680 use settings::{Settings, SettingsStore};
2681 use std::sync::Arc;
2682 use theme::ThemeSettings;
2683 use util::path;
2684 use workspace::Workspace;
2685
    /// Attaching a file as context should embed the file's contents both in the
    /// user message's loaded context and in the resulting completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Two messages total; index 1 is the user message (index 0 is
        // presumably the system/project-context message — see to_completion_request).
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2761
    /// Each message should only attach contexts that were not already included
    /// in an earlier message, and removing contexts from the store shrinks the
    /// "new context" set computed relative to a given message.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // "New contexts" relative to message 2 include everything added at or
        // after message 2 that is still in the store.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2917
    /// Messages inserted with an empty context should carry no context text,
    /// and the completion request should contain only the raw message strings.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2995
    /// Editing a buffer after it was attached as context should cause the next
    /// completion request to end with a "files changed since last read" notice.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3082
    /// `model_parameters` should match on provider and/or model id: a parameter
    /// entry applies when its provider/model fields are unset or equal to the
    /// active model's, and must NOT apply when the provider differs.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3176
    /// Registers all global settings, stores, and tool state the thread tests
    /// depend on. Call once at the start of each test.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3193
3194 // Helper to create a test project with test files
3195 async fn create_test_project(
3196 cx: &mut TestAppContext,
3197 files: serde_json::Value,
3198 ) -> Entity<Project> {
3199 let fs = FakeFs::new(cx.executor());
3200 fs.insert_tree(path!("/test"), files).await;
3201 Project::test(fs, [path!("/test").as_ref()], cx).await
3202 }
3203
    /// Boots the full test environment for a project: a workspace window, a
    /// loaded `ThreadStore`, a fresh empty thread, a context store, and a fake
    /// language model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }
3238
3239 async fn add_file_to_context(
3240 project: &Entity<Project>,
3241 context_store: &Entity<ContextStore>,
3242 path: &str,
3243 cx: &mut TestAppContext,
3244 ) -> Result<Entity<language::Buffer>> {
3245 let buffer_path = project
3246 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3247 .unwrap();
3248
3249 let buffer = project
3250 .update(cx, |project, cx| {
3251 project.open_buffer(buffer_path.clone(), cx)
3252 })
3253 .await
3254 .unwrap();
3255
3256 context_store.update(cx, |context_store, cx| {
3257 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3258 });
3259
3260 Ok(buffer)
3261 }
3262}