1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, TryFutureExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
/// Unique identifier of a [`Thread`]; typically a UUID string (see [`ThreadId::new`]),
/// though `From<&str>` allows arbitrary ids (e.g. when deserializing).
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Generates a fresh, random prompt id.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    // Displays the raw id string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
91
/// Monotonically increasing identifier of a [`Message`] within a thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message text covered by the crease.
    pub range: Range<usize>,
    /// Icon and label used to render the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text/thinking segments making up the message body.
    pub segments: Vec<MessageSegment>,
    /// Context that was attached to this message when it was sent.
    pub loaded_context: LoadedContext,
    /// Creases for re-rendering attached-context mentions in an editor.
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
/// One contiguous piece of a message body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain message text.
    Text(String),
    /// Model "thinking" output with an optional signature from the provider.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque redacted-thinking payload; produces no displayable text.
    RedactedThinking(Vec<u8>),
}
185
186impl MessageSegment {
187 pub fn should_display(&self) -> bool {
188 match self {
189 Self::Text(text) => text.is_empty(),
190 Self::Thinking { text, .. } => text.is_empty(),
191 Self::RedactedThinking(_) => false,
192 }
193 }
194}
195
/// Snapshot of the project's worktrees, captured when a thread starts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree snapshot data, including git state when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Git repository state captured for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// A git checkpoint associated with the message it was recorded for.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Positive or negative user feedback on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
228
/// Outcome of the most recent checkpoint restore attempt.
pub enum LastRestoreCheckpoint {
    /// Restore is still in flight.
    Pending {
        message_id: MessageId,
    },
    /// Restore failed with the given error message.
    Error {
        message_id: MessageId,
        error: String,
    },
}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
/// Progress of generating the thread's detailed summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    // NOTE(review): `message_id` presumably marks the last message covered by
    // the summary — confirm against the summary-generation code.
    Generating {
        message_id: MessageId,
    },
    /// Generation finished, producing `text`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Aggregate token usage for a thread relative to the model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens used so far.
    pub total: usize,
    /// Context-window maximum; 0 when no model is selected.
    pub max: usize,
}
276
277impl TotalTokenUsage {
278 pub fn ratio(&self) -> TokenUsageRatio {
279 #[cfg(debug_assertions)]
280 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
281 .unwrap_or("0.8".to_string())
282 .parse()
283 .unwrap();
284 #[cfg(not(debug_assertions))]
285 let warning_threshold: f32 = 0.8;
286
287 // When the maximum is unknown because there is no selected model,
288 // avoid showing the token limit warning.
289 if self.max == 0 {
290 TokenUsageRatio::Normal
291 } else if self.total >= self.max {
292 TokenUsageRatio::Exceeded
293 } else if self.total as f32 / self.max as f32 >= warning_threshold {
294 TokenUsageRatio::Warning
295 } else {
296 TokenUsageRatio::Normal
297 }
298 }
299
300 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
301 TotalTokenUsage {
302 total: self.total + tokens,
303 max: self.max,
304 }
305 }
306}
307
/// Coarse classification of usage returned by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    /// Usage crossed the warning threshold (80% by default).
    Warning,
    /// Usage reached or exceeded the maximum.
    Exceeded,
}
315
/// Where a pending completion currently is in the request lifecycle.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is waiting in a queue at the given position.
    Queued { position: usize },
    /// The completion has started.
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Generated summary; `None` until one has been produced.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    /// Messages ordered by ascending id (relied upon by the binary search in
    /// [`Thread::message`]).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the message they were recorded for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// State of the most recent checkpoint restore, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken for the latest user message, finalized when the turn ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Snapshot of the project captured when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Timestamp of the last streamed chunk; drives staleness detection.
    last_received_chunk_at: Option<Instant>,
    /// Optional observer of requests and their raw completion events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Automatic turns remaining; initialized to `u32::MAX`.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
363
/// Recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
371
372impl Thread {
    /// Creates a fresh, empty thread wired to the given project, tool set, and
    /// shared system-prompt context.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        // Start with the globally configured default model.
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture a snapshot of the project as it is at thread
                // creation; `Shared` lets `serialize` await it later.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
426
    /// Reconstructs a thread from its serialized representation, resolving the
    /// previously selected model (falling back to the default) and rebuilding
    /// messages, creases, and tool-use state.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model that was in use when the thread was saved; fall
        // back to the global default if it can no longer be resolved.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        // Context handles are not persisted; only text survives.
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Deserialized creases carry no live context handle.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
546
    /// Installs a callback invoked with each completion request and the raw
    /// events it produced.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
            + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
554
    /// The stable identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Generates a fresh prompt id for the next request.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, if one exists yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared handle to the project context used to build the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
582
583 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
584 if self.configured_model.is_none() {
585 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
586 }
587 self.configured_model.clone()
588 }
589
    /// The currently selected model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the selected model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
598
    /// Fallback summary used while no summary has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The summary, or [`Self::DEFAULT_SUMMARY`] when none exists yet.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
604
605 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
606 let Some(current_summary) = &self.summary else {
607 // Don't allow setting summary until generated
608 return;
609 };
610
611 let mut new_summary = new_summary.into();
612
613 if new_summary.is_empty() {
614 new_summary = Self::DEFAULT_SUMMARY;
615 }
616
617 if current_summary != &new_summary {
618 self.summary = Some(new_summary);
619 cx.emit(ThreadEvent::SummaryChanged);
620 }
621 }
622
    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
630
    /// Looks up a message by id via binary search; relies on `messages` being
    /// sorted by ascending id (ids are assigned by `post_inc` on insert).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// Iterates all messages in insertion order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or any tool use is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
647
    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk after which streaming counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        // NOTE(review): this returns `Some` whenever a chunk has ever been
        // received; it assumes `last_received_chunk_at` is cleared when
        // generation ends — confirm against the completion-handling code.
        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
666
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for the given message, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
692
    /// Restores the project to the given git checkpoint and, on success,
    /// truncates the thread back to the checkpoint's message. Progress and
    /// failures are surfaced through `last_restore_checkpoint`.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-progress so observers can reflect it.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failed state so it can be shown to the user.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop messages past the checkpoint and clear restore state.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
728
    /// Once generation has finished, compares the pending checkpoint with the
    /// current project state and keeps the checkpoint only if the project
    /// actually changed (or if the comparison itself failed).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        // Still generating: leave the pending checkpoint for a later call.
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // Treat a failed comparison as "not equal" so the checkpoint
                // is conservatively retained.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Taking the final checkpoint failed: keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
767
    /// Stores a checkpoint under its message id and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
778
779 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
780 let Some(message_ix) = self
781 .messages
782 .iter()
783 .rposition(|message| message.id == message_id)
784 else {
785 return;
786 };
787 for deleted_message in self.messages.drain(message_ix..) {
788 self.checkpoints_by_message.remove(&deleted_message.id);
789 }
790 cx.notify();
791 }
792
793 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
794 self.messages
795 .iter()
796 .find(|message| message.id == id)
797 .into_iter()
798 .flat_map(|message| message.loaded_context.contexts.iter())
799 }
800
801 pub fn is_turn_end(&self, ix: usize) -> bool {
802 if self.messages.is_empty() {
803 return false;
804 }
805
806 if !self.is_generating() && ix == self.messages.len() - 1 {
807 return true;
808 }
809
810 let Some(message) = self.messages.get(ix) else {
811 return false;
812 };
813
814 if message.role != Role::Assistant {
815 return false;
816 }
817
818 self.messages
819 .get(ix + 1)
820 .and_then(|message| {
821 self.message(message.id)
822 .map(|next_message| next_message.role == Role::User)
823 })
824 .unwrap_or(false)
825 }
826
    /// Usage reported by the most recent completion request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the tool-use limit was reported as reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
844
    /// Tool uses issued by the message with the given id.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results associated with the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if one has been recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The content of a tool use's result, if one has been recorded.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool use's result, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
867
868 /// Return tools that are both enabled and supported by the model
869 pub fn available_tools(
870 &self,
871 cx: &App,
872 model: Arc<dyn LanguageModel>,
873 ) -> Vec<LanguageModelRequestTool> {
874 if model.supports_tools() {
875 self.tools()
876 .read(cx)
877 .enabled_tools(cx)
878 .into_iter()
879 .filter_map(|tool| {
880 // Skip tools that cannot be supported
881 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
882 Some(LanguageModelRequestTool {
883 name: tool.name(),
884 description: tool.description(),
885 input_schema,
886 })
887 })
888 .collect()
889 } else {
890 Vec::default()
891 }
892 }
893
    /// Inserts a user message along with its loaded context, optionally
    /// recording a git checkpoint to associate with the message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Register buffers referenced by the attached context as read in the
        // action log.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        // Whether the checkpoint is kept is decided when the turn finishes;
        // see `finalize_pending_checkpoint`.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
929
    /// Inserts an assistant message consisting of the given segments, with no
    /// attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }
943
944 pub fn insert_message(
945 &mut self,
946 role: Role,
947 segments: Vec<MessageSegment>,
948 loaded_context: LoadedContext,
949 creases: Vec<MessageCrease>,
950 cx: &mut Context<Self>,
951 ) -> MessageId {
952 let id = self.next_message_id.post_inc();
953 self.messages.push(Message {
954 id,
955 role,
956 segments,
957 loaded_context,
958 creases,
959 });
960 self.touch_updated_at();
961 cx.emit(ThreadEvent::MessageAdded(id));
962 id
963 }
964
965 pub fn edit_message(
966 &mut self,
967 id: MessageId,
968 new_role: Role,
969 new_segments: Vec<MessageSegment>,
970 loaded_context: Option<LoadedContext>,
971 cx: &mut Context<Self>,
972 ) -> bool {
973 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
974 return false;
975 };
976 message.role = new_role;
977 message.segments = new_segments;
978 if let Some(context) = loaded_context {
979 message.loaded_context = context;
980 }
981 self.touch_updated_at();
982 cx.emit(ThreadEvent::MessageEdited(id));
983 true
984 }
985
986 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
987 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
988 return false;
989 };
990 self.messages.remove(index);
991 self.touch_updated_at();
992 cx.emit(ThreadEvent::MessageDeleted(id));
993 true
994 }
995
996 /// Returns the representation of this [`Thread`] in a textual form.
997 ///
998 /// This is the representation we use when attaching a thread as context to another thread.
999 pub fn text(&self) -> String {
1000 let mut text = String::new();
1001
1002 for message in &self.messages {
1003 text.push_str(match message.role {
1004 language_model::Role::User => "User:",
1005 language_model::Role::Assistant => "Agent:",
1006 language_model::Role::System => "System:",
1007 });
1008 text.push('\n');
1009
1010 for segment in &message.segments {
1011 match segment {
1012 MessageSegment::Text(content) => text.push_str(content),
1013 MessageSegment::Thinking { text: content, .. } => {
1014 text.push_str(&format!("<think>{}</think>", content))
1015 }
1016 MessageSegment::RedactedThinking(_) => {}
1017 }
1018 }
1019 text.push('\n');
1020 }
1021
1022 text
1023 }
1024
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot captured at thread creation before
            // reading the rest of the state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results are stored alongside the message
                        // they belong to.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                // Persist provider/model ids so the selection can be restored
                // in `deserialize`.
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1107
    /// Number of automatic turns the thread may still take.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    /// Sets the number of remaining automatic turns.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }

    /// Sends the current conversation to `model`, consuming one turn; does
    /// nothing when no turns remain.
    pub fn send_to_model(
        &mut self,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.remaining_turns == 0 {
            return;
        }

        self.remaining_turns -= 1;

        let request = self.to_completion_request(model.clone(), cx);

        self.stream_completion(request, model, window, cx);
    }
1132
1133 pub fn used_tools_since_last_user_message(&self) -> bool {
1134 for message in self.messages.iter().rev() {
1135 if self.tool_use.message_has_tool_results(message.id) {
1136 return true;
1137 } else if message.role == Role::User {
1138 return false;
1139 }
1140 }
1141
1142 false
1143 }
1144
    /// Builds the full [`LanguageModelRequest`] for this thread.
    ///
    /// Assembles, in order: the system prompt (from the project context),
    /// every message (loaded context, text/thinking segments, attached tool
    /// uses and their results), a stale-file notice, and the available tools.
    /// The final message is marked as a prompt-cache anchor.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt needs the project context, which is produced
        // asynchronously and may not be ready yet; surface an error to the
        // user instead of silently sending a prompt-less request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            // Empty text/thinking segments are dropped; redacted thinking is
            // forwarded verbatim.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results follow as a separate message, immediately after
            // the message whose tool uses produced them.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // Mark the last message as a cache anchor:
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Only request max mode when the model supports it; otherwise force
        // normal completion mode.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1262
1263 fn to_summarize_request(
1264 &self,
1265 model: &Arc<dyn LanguageModel>,
1266 added_user_message: String,
1267 cx: &App,
1268 ) -> LanguageModelRequest {
1269 let mut request = LanguageModelRequest {
1270 thread_id: None,
1271 prompt_id: None,
1272 mode: None,
1273 messages: vec![],
1274 tools: Vec::new(),
1275 stop: Vec::new(),
1276 temperature: AssistantSettings::temperature_for_model(model, cx),
1277 };
1278
1279 for message in &self.messages {
1280 let mut request_message = LanguageModelRequestMessage {
1281 role: message.role,
1282 content: Vec::new(),
1283 cache: false,
1284 };
1285
1286 for segment in &message.segments {
1287 match segment {
1288 MessageSegment::Text(text) => request_message
1289 .content
1290 .push(MessageContent::Text(text.clone())),
1291 MessageSegment::Thinking { .. } => {}
1292 MessageSegment::RedactedThinking(_) => {}
1293 }
1294 }
1295
1296 if request_message.content.is_empty() {
1297 continue;
1298 }
1299
1300 request.messages.push(request_message);
1301 }
1302
1303 request.messages.push(LanguageModelRequestMessage {
1304 role: Role::User,
1305 content: vec![MessageContent::Text(added_user_message)],
1306 cache: false,
1307 });
1308
1309 request
1310 }
1311
1312 fn attached_tracked_files_state(
1313 &self,
1314 messages: &mut Vec<LanguageModelRequestMessage>,
1315 cx: &App,
1316 ) {
1317 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1318
1319 let mut stale_message = String::new();
1320
1321 let action_log = self.action_log.read(cx);
1322
1323 for stale_file in action_log.stale_buffers(cx) {
1324 let Some(file) = stale_file.read(cx).file() else {
1325 continue;
1326 };
1327
1328 if stale_message.is_empty() {
1329 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1330 }
1331
1332 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1333 }
1334
1335 let mut content = Vec::with_capacity(2);
1336
1337 if !stale_message.is_empty() {
1338 content.push(stale_message.into());
1339 }
1340
1341 if !content.is_empty() {
1342 let context_message = LanguageModelRequestMessage {
1343 role: Role::User,
1344 content,
1345 cache: false,
1346 };
1347
1348 messages.push(context_message);
1349 }
1350 }
1351
    /// Starts streaming a completion for `request` from `model`.
    ///
    /// Spawns a task, tracked as a [`PendingCompletion`], that consumes the
    /// model's event stream and incrementally updates the thread: appending
    /// streamed text/thinking chunks, registering tool uses, tracking token
    /// usage and queue status, and emitting the corresponding
    /// [`ThreadEvent`]s. When the stream finishes (or fails), the task
    /// finalizes the pending checkpoint, runs any requested tools, surfaces
    /// errors to the user, and reports completion telemetry.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record the request/response pair when a debug callback is
        // installed, to avoid cloning the request unnecessarily.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before streaming so the telemetry below can
            // report this request's delta.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                // Malformed tool input is recoverable: record
                                // an error tool result and keep streaming.
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage events report request-level totals, so
                                // fold only the delta into the cumulative count.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to the current assistant
                                // message; create one if none exists yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // While the tool's input JSON is still being
                                // streamed, forward the partial input to the UI.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued { position } => {
                                            completion.queue_state =
                                                QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code,
                                            message,
                                            request_id,
                                        } => {
                                            return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
                                        }
                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
                                            let usage = RequestUsage {
                                                limit,
                                                amount: amount as i32,
                                            };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Let other tasks (e.g. UI updates) run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to dedicated ThreadError
                            // variants; everything else becomes a generic
                            // message assembled from the error chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1698
    /// Generates a short thread title using the configured summary model.
    ///
    /// No-ops when no summary model is configured or its provider is not
    /// authenticated. On success stores the first line of the model's
    /// response as `self.summary` and emits
    /// [`ThreadEvent::SummaryGenerated`]. Replaces any in-flight
    /// summarization task.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            // Track usage even during summarization requests.
                            this.update(cx, |thread, _cx| {
                                thread.last_usage = Some(RequestUsage {
                                    limit,
                                    amount: amount as i32,
                                });
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    // Only the first line of output is used as the title.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1761
    /// Kicks off detailed-summary generation unless one is already
    /// generating (or generated) for the current last message.
    ///
    /// Progress is published through `detailed_summary_tx`/`rx`; when
    /// generation succeeds, the thread is saved via `thread_store` so the
    /// summary can be reused later. No-ops when no summary model is
    /// configured or its provider is not authenticated.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed; reset the state so a retry is possible.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1851
    /// Waits for the detailed-summary state to settle and returns the result.
    ///
    /// Resolves with the generated summary once generation finishes, with
    /// the thread's plain text when no summary is being generated, or
    /// `None` when the thread entity or its state channel is gone.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating; wait for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1869
1870 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1871 self.detailed_summary_rx
1872 .borrow()
1873 .text()
1874 .unwrap_or_else(|| self.text().into())
1875 }
1876
    /// Whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1883
    /// Dispatches every idle pending tool use: tools requiring confirmation
    /// are queued for user approval, the rest run immediately, and tool
    /// names the model hallucinated get an error result.
    ///
    /// Returns the pending tool uses that were dispatched.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Tools receive the full request message list as context.
        let request = self.to_completion_request(model.clone(), cx);
        let messages = Arc::new(request.messages);
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                // Confirmation can be bypassed globally via settings.
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool that doesn't exist or is
                // disabled; report that back as a tool error.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
1938
1939 pub fn handle_hallucinated_tool_use(
1940 &mut self,
1941 tool_use_id: LanguageModelToolUseId,
1942 hallucinated_tool_name: Arc<str>,
1943 window: Option<AnyWindowHandle>,
1944 cx: &mut Context<Thread>,
1945 ) {
1946 let available_tools = self.tools.read(cx).enabled_tools(cx);
1947
1948 let tool_list = available_tools
1949 .iter()
1950 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
1951 .collect::<Vec<_>>()
1952 .join("\n");
1953
1954 let error_message = format!(
1955 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
1956 hallucinated_tool_name, tool_list
1957 );
1958
1959 let pending_tool_use = self.tool_use.insert_tool_output(
1960 tool_use_id.clone(),
1961 hallucinated_tool_name,
1962 Err(anyhow!("Missing tool call: {error_message}")),
1963 self.configured_model.as_ref(),
1964 );
1965
1966 cx.emit(ThreadEvent::MissingToolUse {
1967 tool_use_id: tool_use_id.clone(),
1968 ui_text: error_message.into(),
1969 });
1970
1971 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
1972 }
1973
1974 pub fn receive_invalid_tool_json(
1975 &mut self,
1976 tool_use_id: LanguageModelToolUseId,
1977 tool_name: Arc<str>,
1978 invalid_json: Arc<str>,
1979 error: String,
1980 window: Option<AnyWindowHandle>,
1981 cx: &mut Context<Thread>,
1982 ) {
1983 log::error!("The model returned invalid input JSON: {invalid_json}");
1984
1985 let pending_tool_use = self.tool_use.insert_tool_output(
1986 tool_use_id.clone(),
1987 tool_name,
1988 Err(anyhow!("Error parsing input JSON: {error}")),
1989 self.configured_model.as_ref(),
1990 );
1991 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
1992 pending_tool_use.ui_text.clone()
1993 } else {
1994 log::error!(
1995 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
1996 );
1997 format!("Unknown tool {}", tool_use_id).into()
1998 };
1999
2000 cx.emit(ThreadEvent::InvalidToolInput {
2001 tool_use_id: tool_use_id.clone(),
2002 ui_text,
2003 invalid_input_json: invalid_json,
2004 });
2005
2006 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2007 }
2008
    /// Runs `tool` for the given tool use and marks it as in-progress.
    ///
    /// The actual execution happens in a task spawned by `spawn_tool_use`;
    /// that task is handed to the tool-use state so its completion can be
    /// tracked and rendered under `ui_text`.
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        messages: &[LanguageModelRequestMessage],
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task = self.spawn_tool_use(
            tool_use_id.clone(),
            messages,
            input,
            tool,
            model,
            window,
            cx,
        );
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
2032
    /// Spawns the task that actually executes `tool` and records its output.
    ///
    /// A disabled tool short-circuits to an error result without running.
    /// When the tool produces a card, it is stored immediately; the returned
    /// task awaits the tool's output, inserts it into the tool-use state,
    /// and then invokes `tool_finished`.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2083
2084 fn tool_finished(
2085 &mut self,
2086 tool_use_id: LanguageModelToolUseId,
2087 pending_tool_use: Option<PendingToolUse>,
2088 canceled: bool,
2089 window: Option<AnyWindowHandle>,
2090 cx: &mut Context<Self>,
2091 ) {
2092 if self.all_tools_finished() {
2093 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2094 if !canceled {
2095 self.send_to_model(model.clone(), window, cx);
2096 }
2097 self.auto_capture_telemetry(cx);
2098 }
2099 }
2100
2101 cx.emit(ThreadEvent::ToolFinished {
2102 tool_use_id,
2103 pending_tool_use,
2104 });
2105 }
2106
2107 /// Cancels the last pending completion, if there are any pending.
2108 ///
2109 /// Returns whether a completion was canceled.
2110 pub fn cancel_last_completion(
2111 &mut self,
2112 window: Option<AnyWindowHandle>,
2113 cx: &mut Context<Self>,
2114 ) -> bool {
2115 let mut canceled = self.pending_completions.pop().is_some();
2116
2117 for pending_tool_use in self.tool_use.cancel_pending() {
2118 canceled = true;
2119 self.tool_finished(
2120 pending_tool_use.id.clone(),
2121 Some(pending_tool_use),
2122 true,
2123 window,
2124 cx,
2125 );
2126 }
2127
2128 self.finalize_pending_checkpoint(cx);
2129
2130 if canceled {
2131 cx.emit(ThreadEvent::CompletionCanceled);
2132 }
2133
2134 canceled
2135 }
2136
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits
    /// [`ThreadEvent::CancelEditing`]; no thread state is changed here.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2144
    /// The thread-level feedback rating, if the user has provided one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2148
    /// The feedback rating recorded for `message_id`, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2152
    /// Records `feedback` for a specific message and reports it, along with
    /// a serialized copy of the thread and a project snapshot, via telemetry.
    ///
    /// Returns early with `Ok(())` when the same rating is already recorded
    /// for this message. The heavy serialization work happens on a
    /// background task.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failures degrade to a null payload rather than
            // dropping the feedback event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2210
    /// Records thread-level `feedback` and reports it via telemetry.
    ///
    /// When an assistant message exists, the rating is attached to the most
    /// recent one via `report_message_feedback`; otherwise a thread-level
    /// event (without message details) is reported directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failures degrade to a null payload rather
                // than dropping the feedback event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2256
2257 /// Create a snapshot of the current project state including git information and unsaved buffers.
2258 fn project_snapshot(
2259 project: Entity<Project>,
2260 cx: &mut Context<Self>,
2261 ) -> Task<Arc<ProjectSnapshot>> {
2262 let git_store = project.read(cx).git_store().clone();
2263 let worktree_snapshots: Vec<_> = project
2264 .read(cx)
2265 .visible_worktrees(cx)
2266 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2267 .collect();
2268
2269 cx.spawn(async move |_, cx| {
2270 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2271
2272 let mut unsaved_buffers = Vec::new();
2273 cx.update(|app_cx| {
2274 let buffer_store = project.read(app_cx).buffer_store();
2275 for buffer_handle in buffer_store.read(app_cx).buffers() {
2276 let buffer = buffer_handle.read(app_cx);
2277 if buffer.is_dirty() {
2278 if let Some(file) = buffer.file() {
2279 let path = file.path().to_string_lossy().to_string();
2280 unsaved_buffers.push(path);
2281 }
2282 }
2283 }
2284 })
2285 .ok();
2286
2287 Arc::new(ProjectSnapshot {
2288 worktree_snapshots,
2289 unsaved_buffer_paths: unsaved_buffers,
2290 timestamp: Utc::now(),
2291 })
2292 })
2293 }
2294
    /// Captures a [`WorktreeSnapshot`] for a single worktree, including git
    /// metadata (remote URL, HEAD SHA, branch, diff) when the worktree belongs
    /// to a local repository.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Locate the repository whose root contains this worktree, then
            // ask it for its git state via its background job queue.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Non-local repositories have no backend to query,
                            // so only the branch name can be reported.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Each fallible layer (entity update, job submission, job result)
            // collapses to `None` on failure.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2372
2373 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2374 let mut markdown = Vec::new();
2375
2376 if let Some(summary) = self.summary() {
2377 writeln!(markdown, "# {summary}\n")?;
2378 };
2379
2380 for message in self.messages() {
2381 writeln!(
2382 markdown,
2383 "## {role}\n",
2384 role = match message.role {
2385 Role::User => "User",
2386 Role::Assistant => "Agent",
2387 Role::System => "System",
2388 }
2389 )?;
2390
2391 if !message.loaded_context.text.is_empty() {
2392 writeln!(markdown, "{}", message.loaded_context.text)?;
2393 }
2394
2395 if !message.loaded_context.images.is_empty() {
2396 writeln!(
2397 markdown,
2398 "\n{} images attached as context.\n",
2399 message.loaded_context.images.len()
2400 )?;
2401 }
2402
2403 for segment in &message.segments {
2404 match segment {
2405 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2406 MessageSegment::Thinking { text, .. } => {
2407 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2408 }
2409 MessageSegment::RedactedThinking(_) => {}
2410 }
2411 }
2412
2413 for tool_use in self.tool_uses_for_message(message.id, cx) {
2414 writeln!(
2415 markdown,
2416 "**Use Tool: {} ({})**",
2417 tool_use.name, tool_use.id
2418 )?;
2419 writeln!(markdown, "```json")?;
2420 writeln!(
2421 markdown,
2422 "{}",
2423 serde_json::to_string_pretty(&tool_use.input)?
2424 )?;
2425 writeln!(markdown, "```")?;
2426 }
2427
2428 for tool_result in self.tool_results_for_message(message.id) {
2429 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2430 if tool_result.is_error {
2431 write!(markdown, " (Error)")?;
2432 }
2433
2434 writeln!(markdown, "**\n")?;
2435 writeln!(markdown, "{}", tool_result.content)?;
2436 }
2437 }
2438
2439 Ok(String::from_utf8_lossy(&markdown).to_string())
2440 }
2441
2442 pub fn keep_edits_in_range(
2443 &mut self,
2444 buffer: Entity<language::Buffer>,
2445 buffer_range: Range<language::Anchor>,
2446 cx: &mut Context<Self>,
2447 ) {
2448 self.action_log.update(cx, |action_log, cx| {
2449 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2450 });
2451 }
2452
2453 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2454 self.action_log
2455 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2456 }
2457
2458 pub fn reject_edits_in_ranges(
2459 &mut self,
2460 buffer: Entity<language::Buffer>,
2461 buffer_ranges: Vec<Range<language::Anchor>>,
2462 cx: &mut Context<Self>,
2463 ) -> Task<Result<()>> {
2464 self.action_log.update(cx, |action_log, cx| {
2465 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2466 })
2467 }
2468
    /// The [`ActionLog`] tracking the edits this thread has made to buffers.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2472
    /// The [`Project`] this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2476
2477 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2478 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2479 return;
2480 }
2481
2482 let now = Instant::now();
2483 if let Some(last) = self.last_auto_capture_at {
2484 if now.duration_since(last).as_secs() < 10 {
2485 return;
2486 }
2487 }
2488
2489 self.last_auto_capture_at = Some(now);
2490
2491 let thread_id = self.id().clone();
2492 let github_login = self
2493 .project
2494 .read(cx)
2495 .user_store()
2496 .read(cx)
2497 .current_user()
2498 .map(|user| user.github_login.clone());
2499 let client = self.project.read(cx).client().clone();
2500 let serialize_task = self.serialize(cx);
2501
2502 cx.background_executor()
2503 .spawn(async move {
2504 if let Ok(serialized_thread) = serialize_task.await {
2505 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2506 telemetry::event!(
2507 "Agent Thread Auto-Captured",
2508 thread_id = thread_id.to_string(),
2509 thread_data = thread_data,
2510 auto_capture_reason = "tracked_user",
2511 github_login = github_login
2512 );
2513
2514 client.telemetry().flush_events().await;
2515 }
2516 }
2517 })
2518 .detach();
2519 }
2520
    /// The token usage accumulated across all requests made by this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2524
2525 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2526 let Some(model) = self.configured_model.as_ref() else {
2527 return TotalTokenUsage::default();
2528 };
2529
2530 let max = model.model.max_token_count();
2531
2532 let index = self
2533 .messages
2534 .iter()
2535 .position(|msg| msg.id == message_id)
2536 .unwrap_or(0);
2537
2538 if index == 0 {
2539 return TotalTokenUsage { total: 0, max };
2540 }
2541
2542 let token_usage = &self
2543 .request_token_usage
2544 .get(index - 1)
2545 .cloned()
2546 .unwrap_or_default();
2547
2548 TotalTokenUsage {
2549 total: token_usage.total_tokens() as usize,
2550 max,
2551 }
2552 }
2553
2554 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2555 let model = self.configured_model.as_ref()?;
2556
2557 let max = model.model.max_token_count();
2558
2559 if let Some(exceeded_error) = &self.exceeded_window_error {
2560 if model.model.id() == exceeded_error.model_id {
2561 return Some(TotalTokenUsage {
2562 total: exceeded_error.token_count,
2563 max,
2564 });
2565 }
2566 }
2567
2568 let total = self
2569 .token_usage_at_last_message()
2570 .unwrap_or_default()
2571 .total_tokens() as usize;
2572
2573 Some(TotalTokenUsage { total, max })
2574 }
2575
2576 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2577 self.request_token_usage
2578 .get(self.messages.len().saturating_sub(1))
2579 .or_else(|| self.request_token_usage.last())
2580 .cloned()
2581 }
2582
2583 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2584 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2585 self.request_token_usage
2586 .resize(self.messages.len(), placeholder);
2587
2588 if let Some(last) = self.request_token_usage.last_mut() {
2589 *last = token_usage;
2590 }
2591 }
2592
2593 pub fn deny_tool_use(
2594 &mut self,
2595 tool_use_id: LanguageModelToolUseId,
2596 tool_name: Arc<str>,
2597 window: Option<AnyWindowHandle>,
2598 cx: &mut Context<Self>,
2599 ) {
2600 let err = Err(anyhow::anyhow!(
2601 "Permission to run tool action denied by user"
2602 ));
2603
2604 self.tool_use.insert_tool_output(
2605 tool_use_id.clone(),
2606 tool_name,
2607 err,
2608 self.configured_model.as_ref(),
2609 );
2610 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2611 }
2612}
2613
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before further requests can be made.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The plan's model request quota has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short header and a longer message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2628
/// Events emitted by a [`Thread`] as completions stream, tools run, and
/// messages change.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// A completion event was received from the model.
    StreamedCompletion,
    /// A text chunk arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant text streamed into the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text streamed into the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use is being streamed from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model requested a tool that is not available.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that could not be parsed.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, with either a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint state changed.
    CheckpointChanged,
    /// A tool needs user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Editing was canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2671
/// Lets `Thread` broadcast [`ThreadEvent`]s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2673
/// A completion request that is currently in flight, together with the task
/// driving it.
struct PendingCompletion {
    /// Identifier distinguishing this completion from other pending ones.
    id: usize,
    /// Where this completion currently stands in the request queue.
    queue_state: QueueState,
    /// The task driving the completion; dropping it stops the work.
    _task: Task<()>,
}
2679
2680#[cfg(test)]
2681mod tests {
2682 use super::*;
2683 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2684 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2685 use assistant_tool::ToolRegistry;
2686 use editor::EditorSettings;
2687 use gpui::TestAppContext;
2688 use language_model::fake_provider::FakeLanguageModel;
2689 use project::{FakeFs, Project};
2690 use prompt_store::PromptBuilder;
2691 use serde_json::json;
2692 use settings::{Settings, SettingsStore};
2693 use std::sync::Arc;
2694 use theme::ThemeSettings;
2695 use util::path;
2696 use workspace::Workspace;
2697
    /// A user message with an attached file should carry the rendered
    /// `<context>` block, and the completion request should prepend that
    /// context to the message text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Two messages total: one leading message plus the user message with
        // its context prepended.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2773
    /// Each message should only pick up context that earlier messages have
    /// not already included; removing contexts should also be reflected when
    /// re-computing the new contexts relative to an older message.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Relative to message 2, contexts 2-4 count as "new" (message 1's
        // context stays excluded).
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2929
    /// Messages inserted without any context should have empty context text
    /// and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3007
    /// After a context buffer is edited, the completion request should end
    /// with a "files changed since last read" notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3094
    /// The `model_parameters` temperature override should apply when the
    /// provider and/or model match the configured model, and not otherwise.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider: the override must NOT apply.
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3188
    /// Registers all the settings/stores the thread tests depend on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3205
    // Helper to create a test project with test files
    /// Builds a fake-fs project rooted at `/test` populated with `files`
    /// (a `json!` tree of path → contents).
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }
3215
    /// Creates the full test environment: a workspace, a loaded
    /// [`ThreadStore`], a fresh thread, an empty context store, and a
    /// [`FakeLanguageModel`].
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }
3250
    /// Opens `path` in the project and registers the resulting buffer as
    /// context in `context_store`, returning the opened buffer.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
3274}