1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use util::{ResultExt as _, TryFutureExt as _, post_inc};
39use uuid::Uuid;
40use zed_llm_client::CompletionRequestStatus;
41
42use crate::ThreadStore;
43use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
44use crate::thread_store::{
45 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
46 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
47};
48use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
49
/// Unique identifier for a [`Thread`], backed by a cheaply clonable `Arc<str>`.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
73/// The ID of the user prompt that initiated a request.
74///
75/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>); // `Arc<str>` keeps clones cheap.
78
79impl PromptId {
80 pub fn new() -> Self {
81 Self(Uuid::new_v4().to_string().into())
82 }
83}
84
85impl std::fmt::Display for PromptId {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 write!(f, "{}", self.0)
88 }
89}
90
/// Monotonically increasing identifier of a message within a thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
93
94impl MessageId {
95 fn post_inc(&mut self) -> Self {
96 Self(post_inc(&mut self.0))
97 }
98}
99
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message the crease covers.
    pub range: Range<usize>,
    /// Icon and label used to render the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
108
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text/thinking segments making up the message body.
    pub segments: Vec<MessageSegment>,
    /// Context that was attached when the message was created.
    pub loaded_context: LoadedContext,
    /// Creases used to restore collapsed context mentions in editors.
    pub creases: Vec<MessageCrease>,
}
118
119impl Message {
120 /// Returns whether the message contains any meaningful text that should be displayed
121 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
122 pub fn should_display_content(&self) -> bool {
123 self.segments.iter().all(|segment| segment.should_display())
124 }
125
126 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
127 if let Some(MessageSegment::Thinking {
128 text: segment,
129 signature: current_signature,
130 }) = self.segments.last_mut()
131 {
132 if let Some(signature) = signature {
133 *current_signature = Some(signature);
134 }
135 segment.push_str(text);
136 } else {
137 self.segments.push(MessageSegment::Thinking {
138 text: text.to_string(),
139 signature,
140 });
141 }
142 }
143
144 pub fn push_text(&mut self, text: &str) {
145 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
146 segment.push_str(text);
147 } else {
148 self.segments.push(MessageSegment::Text(text.to_string()));
149 }
150 }
151
152 pub fn to_string(&self) -> String {
153 let mut result = String::new();
154
155 if !self.loaded_context.text.is_empty() {
156 result.push_str(&self.loaded_context.text);
157 }
158
159 for segment in &self.segments {
160 match segment {
161 MessageSegment::Text(text) => result.push_str(text),
162 MessageSegment::Thinking { text, .. } => {
163 result.push_str("<think>\n");
164 result.push_str(text);
165 result.push_str("\n</think>");
166 }
167 MessageSegment::RedactedThinking(_) => {}
168 }
169 }
170
171 result
172 }
173}
174
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment carries user-visible content worth rendering.
    ///
    /// Text and thinking segments are displayable only when non-empty; redacted
    /// thinking has no renderable content.
    pub fn should_display(&self) -> bool {
        match self {
            // Fixed: these previously returned `is_empty()`, which inverted the
            // check — empty segments were reported as displayable and non-empty
            // ones as not.
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
194
/// Point-in-time capture of the project's worktrees and unsaved buffers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
201
/// Snapshot of a single worktree, including git metadata when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Git metadata, when the worktree has an associated repository.
    pub git_state: Option<GitState>,
}
207
/// Captured git metadata for a worktree's repository.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the remote, if one is configured.
    pub remote_url: Option<String>,
    /// SHA of HEAD, if a commit exists.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch, if any.
    pub current_branch: Option<String>,
    /// Textual diff of changes, if one was collected.
    pub diff: Option<String>,
}
215
/// Pairs a message with the git checkpoint captured when it was inserted.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
221
/// Positive or negative user feedback on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
227
/// Outcome of the most recent attempt to restore a checkpoint.
pub enum LastRestoreCheckpoint {
    /// The restore for this message's checkpoint is still running.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed; `error` holds the display string.
    Error {
        message_id: MessageId,
        error: String,
    },
}
237
238impl LastRestoreCheckpoint {
239 pub fn message_id(&self) -> MessageId {
240 match self {
241 LastRestoreCheckpoint::Pending { message_id } => *message_id,
242 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
243 }
244 }
245}
246
/// Progress of the optional detailed (long-form) thread summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been produced yet.
    #[default]
    NotGenerated,
    // `message_id` presumably marks the message the summary run covers — confirm
    // at the generation site (not visible in this chunk).
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary is available.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
259
260impl DetailedSummaryState {
261 fn text(&self) -> Option<SharedString> {
262 if let Self::Generated { text, .. } = self {
263 Some(text.clone())
264 } else {
265 None
266 }
267 }
268}
269
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Maximum tokens available; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, nearing the limit, or exceeded.
    ///
    /// In debug builds the warning threshold can be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable; an unset or
    /// unparsable value falls back to the default of 0.8. (Previously a set but
    /// malformed value would panic via `unwrap()`.)
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close the conversation is to the model's context-window limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
314
/// Lifecycle of a pending completion, as surfaced by [`Thread::queue_state`].
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is waiting at the given queue position.
    Queued { position: usize },
    /// The request has started.
    Started,
}
321
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// `None` until a summary has been generated for the thread.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    /// Messages ordered by ascending `MessageId` (binary search in `message` relies on this).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Most recent checkpoint-restore attempt, kept for UI reporting.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured at the latest user message; finalized when the turn ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// When the last streamed chunk arrived; used for staleness checks.
    last_received_chunk_at: Option<Instant>,
    /// Optional observer of each completion request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Model turns still allowed; `u32::MAX` means effectively unlimited.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
362
/// Details recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
370
371impl Thread {
    /// Creates an empty thread with a new ID, configured from global settings,
    /// capturing an initial snapshot of the project in the background.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project state asynchronously on the foreground executor.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
425
    /// Reconstructs a [`Thread`] from its serialized form.
    ///
    /// Message numbering continues after the highest persisted ID, the saved
    /// language model is re-resolved against the registry (falling back to the
    /// default), and tool-use state is rebuilt from the serialized messages.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the persisted model selection; fall back to the global default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Context handles are not persisted, so restored creases have none.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
538
    /// Installs a callback that observes each completion request together with
    /// the events it produced.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
546
    /// The unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Stamps the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, or `None` if not yet generated.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, initializing it from the registry's
    /// default when not yet set.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the configured model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Placeholder summary used before a real one is generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The summary, or [`Self::DEFAULT_SUMMARY`] if none exists yet.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
596
597 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
598 let Some(current_summary) = &self.summary else {
599 // Don't allow setting summary until generated
600 return;
601 };
602
603 let mut new_summary = new_summary.into();
604
605 if new_summary.is_empty() {
606 new_summary = Self::DEFAULT_SUMMARY;
607 }
608
609 if current_summary != &new_summary {
610 self.summary = Some(new_summary);
611 cx.emit(ThreadEvent::SummaryChanged);
612 }
613 }
614
    /// The currently selected completion mode.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the completion mode used for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
622
623 pub fn message(&self, id: MessageId) -> Option<&Message> {
624 let index = self
625 .messages
626 .binary_search_by(|message| message.id.cmp(&id))
627 .ok()?;
628
629 self.messages.get(index)
630 }
631
632 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
633 self.messages.iter()
634 }
635
    /// Returns whether a completion is in flight or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    // NOTE(review): this only inspects `last_received_chunk_at`; the documented
    // relationship with `is_generating()` holds only if that field is `None`
    // whenever generation is idle — confirm where it is cleared (not visible here).
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before streaming counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue status of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
658
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Iterates over pending tool uses that await user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool use is still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
684
    /// Restores the project to `checkpoint`, truncating the thread back to the
    /// checkpoint's message on success. Progress and failure are surfaced via
    /// `last_restore_checkpoint` and [`ThreadEvent::CheckpointChanged`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so observers can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so observers can report why the restore failed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // On success, drop the checkpointed message and everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
720
    /// Finalizes the checkpoint captured at the last user message: once
    /// generation has fully finished, keep it only if the project actually
    /// changed relative to a fresh checkpoint (or if the comparison fails).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        // Wait until both the completion and all tool runs have finished.
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep the checkpoint when the project state changed.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If a final checkpoint can't be taken, conservatively keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
759
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The most recent checkpoint-restore attempt, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
770
771 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
772 let Some(message_ix) = self
773 .messages
774 .iter()
775 .rposition(|message| message.id == message_id)
776 else {
777 return;
778 };
779 for deleted_message in self.messages.drain(message_ix..) {
780 self.checkpoints_by_message.remove(&deleted_message.id);
781 }
782 cx.notify();
783 }
784
785 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
786 self.messages
787 .iter()
788 .find(|message| message.id == id)
789 .into_iter()
790 .flat_map(|message| message.loaded_context.contexts.iter())
791 }
792
793 pub fn is_turn_end(&self, ix: usize) -> bool {
794 if self.messages.is_empty() {
795 return false;
796 }
797
798 if !self.is_generating() && ix == self.messages.len() - 1 {
799 return true;
800 }
801
802 let Some(message) = self.messages.get(ix) else {
803 return false;
804 };
805
806 if message.role != Role::Assistant {
807 return false;
808 }
809
810 self.messages
811 .get(ix + 1)
812 .and_then(|message| {
813 self.message(message.id)
814 .map(|next_message| next_message.role == Role::User)
815 })
816 .unwrap_or(false)
817 }
818
    /// Usage reported by the most recent completion request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the tool-use limit has been reached for this thread.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
836
    /// Tool uses associated with the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results associated with the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result recorded for a specific tool use, if any.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The textual content of a tool use's result, if available.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The card rendered for a tool use's result, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
859
860 /// Return tools that are both enabled and supported by the model
861 pub fn available_tools(
862 &self,
863 cx: &App,
864 model: Arc<dyn LanguageModel>,
865 ) -> Vec<LanguageModelRequestTool> {
866 if model.supports_tools() {
867 self.tools()
868 .read(cx)
869 .enabled_tools(cx)
870 .into_iter()
871 .filter_map(|tool| {
872 // Skip tools that cannot be supported
873 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
874 Some(LanguageModelRequestTool {
875 name: tool.name(),
876 description: tool.description(),
877 input_schema,
878 })
879 })
880 .collect()
881 } else {
882 Vec::default()
883 }
884 }
885
886 pub fn insert_user_message(
887 &mut self,
888 text: impl Into<String>,
889 loaded_context: ContextLoadResult,
890 git_checkpoint: Option<GitStoreCheckpoint>,
891 creases: Vec<MessageCrease>,
892 cx: &mut Context<Self>,
893 ) -> MessageId {
894 if !loaded_context.referenced_buffers.is_empty() {
895 self.action_log.update(cx, |log, cx| {
896 for buffer in loaded_context.referenced_buffers {
897 log.buffer_read(buffer, cx);
898 }
899 });
900 }
901
902 let message_id = self.insert_message(
903 Role::User,
904 vec![MessageSegment::Text(text.into())],
905 loaded_context.loaded_context,
906 creases,
907 cx,
908 );
909
910 if let Some(git_checkpoint) = git_checkpoint {
911 self.pending_checkpoint = Some(ThreadCheckpoint {
912 message_id,
913 git_checkpoint,
914 });
915 }
916
917 self.auto_capture_telemetry(cx);
918
919 message_id
920 }
921
922 pub fn insert_assistant_message(
923 &mut self,
924 segments: Vec<MessageSegment>,
925 cx: &mut Context<Self>,
926 ) -> MessageId {
927 self.insert_message(
928 Role::Assistant,
929 segments,
930 LoadedContext::default(),
931 Vec::new(),
932 cx,
933 )
934 }
935
936 pub fn insert_message(
937 &mut self,
938 role: Role,
939 segments: Vec<MessageSegment>,
940 loaded_context: LoadedContext,
941 creases: Vec<MessageCrease>,
942 cx: &mut Context<Self>,
943 ) -> MessageId {
944 let id = self.next_message_id.post_inc();
945 self.messages.push(Message {
946 id,
947 role,
948 segments,
949 loaded_context,
950 creases,
951 });
952 self.touch_updated_at();
953 cx.emit(ThreadEvent::MessageAdded(id));
954 id
955 }
956
957 pub fn edit_message(
958 &mut self,
959 id: MessageId,
960 new_role: Role,
961 new_segments: Vec<MessageSegment>,
962 loaded_context: Option<LoadedContext>,
963 cx: &mut Context<Self>,
964 ) -> bool {
965 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
966 return false;
967 };
968 message.role = new_role;
969 message.segments = new_segments;
970 if let Some(context) = loaded_context {
971 message.loaded_context = context;
972 }
973 self.touch_updated_at();
974 cx.emit(ThreadEvent::MessageEdited(id));
975 true
976 }
977
978 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
979 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
980 return false;
981 };
982 self.messages.remove(index);
983 self.touch_updated_at();
984 cx.emit(ThreadEvent::MessageDeleted(id));
985 true
986 }
987
988 /// Returns the representation of this [`Thread`] in a textual form.
989 ///
990 /// This is the representation we use when attaching a thread as context to another thread.
991 pub fn text(&self) -> String {
992 let mut text = String::new();
993
994 for message in &self.messages {
995 text.push_str(match message.role {
996 language_model::Role::User => "User:",
997 language_model::Role::Assistant => "Agent:",
998 language_model::Role::System => "System:",
999 });
1000 text.push('\n');
1001
1002 for segment in &message.segments {
1003 match segment {
1004 MessageSegment::Text(content) => text.push_str(content),
1005 MessageSegment::Thinking { text: content, .. } => {
1006 text.push_str(&format!("<think>{}</think>", content))
1007 }
1008 MessageSegment::RedactedThinking(_) => {}
1009 }
1010 }
1011 text.push('\n');
1012 }
1013
1014 text
1015 }
1016
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot before producing the
    /// [`SerializedThread`]; tool uses and results are looked up per message.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1098
    /// Number of model turns this thread may still take (see `send_to_model`).
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }

    /// Sets the number of model turns this thread may still take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1106
1107 pub fn send_to_model(
1108 &mut self,
1109 model: Arc<dyn LanguageModel>,
1110 window: Option<AnyWindowHandle>,
1111 cx: &mut Context<Self>,
1112 ) {
1113 if self.remaining_turns == 0 {
1114 return;
1115 }
1116
1117 self.remaining_turns -= 1;
1118
1119 let request = self.to_completion_request(model.clone(), cx);
1120
1121 self.stream_completion(request, model, window, cx);
1122 }
1123
1124 pub fn used_tools_since_last_user_message(&self) -> bool {
1125 for message in self.messages.iter().rev() {
1126 if self.tool_use.message_has_tool_results(message.id) {
1127 return true;
1128 } else if message.role == Role::User {
1129 return false;
1130 }
1131 }
1132
1133 false
1134 }
1135
    /// Assembles the complete `LanguageModelRequest` for this thread: the
    /// system prompt, every message (attached context, text/thinking
    /// segments, tool uses and tool results), stale-file notices, the
    /// available tools, and the completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        // The prompt template is told which tools exist so it can describe
        // them in the system prompt.
        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from the project context. If the context
        // isn't ready yet, or generation fails, surface an error to the UI
        // but still send the request without a system prompt.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt is large and stable across requests,
                    // so mark it cacheable.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, …) goes first in the message.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        // Skip empty segments to avoid sending blank content.
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            // Tool results are sent as a separate message immediately after
            // the message whose tool uses produced them.
            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // Mark the final message as a cache anchor so providers that support
        // prompt caching can reuse the shared prefix on the next request.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        // Append a notice about files that changed since the model last read
        // them (no-op when there are none).
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Only request max mode when the model supports it; otherwise fall
        // back to normal mode explicitly.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1253
1254 fn to_summarize_request(
1255 &self,
1256 model: &Arc<dyn LanguageModel>,
1257 added_user_message: String,
1258 cx: &App,
1259 ) -> LanguageModelRequest {
1260 let mut request = LanguageModelRequest {
1261 thread_id: None,
1262 prompt_id: None,
1263 mode: None,
1264 messages: vec![],
1265 tools: Vec::new(),
1266 stop: Vec::new(),
1267 temperature: AssistantSettings::temperature_for_model(model, cx),
1268 };
1269
1270 for message in &self.messages {
1271 let mut request_message = LanguageModelRequestMessage {
1272 role: message.role,
1273 content: Vec::new(),
1274 cache: false,
1275 };
1276
1277 for segment in &message.segments {
1278 match segment {
1279 MessageSegment::Text(text) => request_message
1280 .content
1281 .push(MessageContent::Text(text.clone())),
1282 MessageSegment::Thinking { .. } => {}
1283 MessageSegment::RedactedThinking(_) => {}
1284 }
1285 }
1286
1287 if request_message.content.is_empty() {
1288 continue;
1289 }
1290
1291 request.messages.push(request_message);
1292 }
1293
1294 request.messages.push(LanguageModelRequestMessage {
1295 role: Role::User,
1296 content: vec![MessageContent::Text(added_user_message)],
1297 cache: false,
1298 });
1299
1300 request
1301 }
1302
1303 fn attached_tracked_files_state(
1304 &self,
1305 messages: &mut Vec<LanguageModelRequestMessage>,
1306 cx: &App,
1307 ) {
1308 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1309
1310 let mut stale_message = String::new();
1311
1312 let action_log = self.action_log.read(cx);
1313
1314 for stale_file in action_log.stale_buffers(cx) {
1315 let Some(file) = stale_file.read(cx).file() else {
1316 continue;
1317 };
1318
1319 if stale_message.is_empty() {
1320 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1321 }
1322
1323 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1324 }
1325
1326 let mut content = Vec::with_capacity(2);
1327
1328 if !stale_message.is_empty() {
1329 content.push(stale_message.into());
1330 }
1331
1332 if !content.is_empty() {
1333 let context_message = LanguageModelRequestMessage {
1334 role: Role::User,
1335 content,
1336 cache: false,
1337 };
1338
1339 messages.push(context_message);
1340 }
1341 }
1342
    /// Starts streaming a completion from `model` for `request`.
    ///
    /// Registers a `PendingCompletion` (kept until the stream finishes or is
    /// canceled). Streamed events mutate the thread in place — appending
    /// text/thinking to the latest assistant message, recording tool uses,
    /// tracking token usage — and are re-emitted as `ThreadEvent`s for the
    /// UI. On completion, either pending tools are dispatched (on a
    /// `ToolUse` stop) or the agent location is cleared; errors are mapped
    /// to user-facing `ThreadEvent::ShowError`s.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record the request and its response events when a debug
        // callback is installed, to avoid cloning the request needlessly.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot cumulative usage before streaming so the per-request
            // delta can be reported to telemetry afterwards.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // The assistant message for this response is created lazily,
                // on the first event that needs one.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                // Malformed tool-input JSON is not fatal:
                                // record an error result for that tool use
                                // and keep consuming the stream.
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Usage updates are running totals for this
                                // request, so fold in only the delta since
                                // the previous update.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                // Append to the trailing assistant message
                                // unless it already has tool results, in
                                // which case a fresh message is started.
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                // Same append-or-start logic as the `Text`
                                // arm, but for thinking segments.
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses need an assistant message to hang
                                // off of; create one if none exists yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only partially-streamed input is forwarded
                                // as a streaming event below.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                // Status updates only apply while this
                                // completion is still pending.
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message
                                        } => {
                                            return Err(anyhow!("completion request failed. code: {code}, message: {message}"));
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so other tasks (e.g. UI updates)
                    // get a chance to run.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if thread.summary.is_none()
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map well-known error types to dedicated UI
                            // errors; everything else becomes a generic
                            // message built from the error chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    // Hand the captured request/response pair to the debug
                    // callback, if one was installed.
                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report the tokens consumed by this request (cumulative
                    // minus the pre-request snapshot).
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1689
    /// Kicks off a background task that asks the configured summary model for
    /// a short (3-7 word) thread title, storing it in `self.summary` and
    /// emitting `ThreadEvent::SummaryGenerated` when done.
    ///
    /// Silently does nothing when no summary model is configured or its
    /// provider is not authenticated. Replaces any in-flight summary task.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);

        self.pending_summary = cx.spawn(async move |this, cx| {
            // Inner async block so `.log_err()` can swallow any failure
            // without aborting the outer task.
            async move {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            // Even a summary request counts against usage, so
                            // keep `last_usage` current.
                            this.update(cx, |thread, _cx| {
                                thread.last_usage = Some(RequestUsage {
                                    limit,
                                    amount: amount as i32,
                                });
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    // Only the first line of output is used as the title.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1752
    /// Starts generating a detailed Markdown summary of the thread in the
    /// background, unless one is already generated (or generating) for the
    /// current last message.
    ///
    /// Progress is published through the `detailed_summary_tx` watch channel;
    /// on success the thread is saved via `thread_store` so the summary can
    /// be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the stream can't even be opened, roll the state back to
            // NotGenerated so a later attempt can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1842
    /// Waits on the detailed-summary watch channel until generation settles:
    /// returns the generated summary, the plain thread text when the state is
    /// `NotGenerated`, or `None` if the channel closes or the entity is gone.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still in progress — keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1860
1861 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1862 self.detailed_summary_rx
1863 .borrow()
1864 .text()
1865 .unwrap_or_else(|| self.text().into())
1866 }
1867
    /// True while a detailed summary is being generated in the background.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1874
    /// Dispatches every idle pending tool use: tools requiring confirmation
    /// (when auto-approval is off) are queued for user approval, known tools
    /// are run immediately, and unknown tool names are reported back to the
    /// model as errors. Returns the tool uses that were dispatched.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Tools receive the full request message history as context.
        let request = self.to_completion_request(model, cx);
        let messages = Arc::new(request.messages);
        // Only dispatch tool uses that haven't been started yet.
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    // Park the tool use until the user approves it.
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        messages.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        &messages,
                        tool,
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool that doesn't exist (or is
                // disabled); tell it so instead of failing silently.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
1928
    /// Records an error result for a tool use whose name doesn't match any
    /// enabled tool, listing the tools that *are* available so the model can
    /// correct itself on the next turn.
    pub fn handle_hallucinated_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        hallucinated_tool_name: Arc<str>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let available_tools = self.tools.read(cx).enabled_tools(cx);

        // One "- name: description" line per available tool.
        let tool_list = available_tools
            .iter()
            .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
            .collect::<Vec<_>>()
            .join("\n");

        let error_message = format!(
            "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
            hallucinated_tool_name, tool_list
        );

        // Stored as the tool's output so it is sent back to the model.
        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            hallucinated_tool_name,
            Err(anyhow!("Missing tool call: {error_message}")),
            self.configured_model.as_ref(),
        );

        cx.emit(ThreadEvent::MissingToolUse {
            tool_use_id: tool_use_id.clone(),
            ui_text: error_message.into(),
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
1963
    /// Handles a tool use whose input JSON failed to parse: records a parse
    /// error as the tool's output (so the model sees it), emits an
    /// `InvalidToolInput` event for the UI, and finalizes the tool use.
    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            self.configured_model.as_ref(),
        );
        // Prefer the pending tool use's UI label; fall back to a generic one
        // if the tool use is unexpectedly untracked.
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
1998
    /// Spawns `tool` with `input` and marks the tool use as running, showing
    /// `ui_text` in the UI while it executes.
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        messages: &[LanguageModelRequestMessage],
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
2013
    /// Runs `tool` (unless it has been disabled, in which case an error
    /// result is produced immediately) and returns a task that records the
    /// tool's output on the thread when it completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // A tool can be disabled between being requested and being run;
        // surface that as an error result rather than executing it.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool ran; `.ok()`
                // makes that a no-op.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2062
    /// Called whenever a tool use completes (successfully, with an error, or
    /// canceled). Once *all* tools have finished and the run wasn't canceled,
    /// the thread automatically continues the conversation with the model.
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
                // Don't auto-continue after a cancellation.
                if !canceled {
                    self.send_to_model(model.clone(), window, cx);
                }
                self.auto_capture_telemetry(cx);
            }
        }

        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }
2085
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Also cancels all pending tool uses and finalizes the pending
    /// checkpoint; emits `CompletionCanceled` when anything was canceled.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        let mut canceled = self.pending_completions.pop().is_some();

        // Each canceled tool still goes through the normal finish path
        // (with `canceled = true`, so the thread won't auto-continue).
        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        self.finalize_pending_checkpoint(cx);

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);
        }

        canceled
    }
2115
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2123
    /// Returns the thread-level feedback rating, if one has been reported.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2127
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2131
    /// Records `feedback` for `message_id` and reports it to telemetry along
    /// with a serialized snapshot of the thread and project.
    ///
    /// A no-op (returning an immediately-ready `Ok`) when the same rating is
    /// already recorded for that message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A serialization failure shouldn't block the rating; fall back
            // to null thread data.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2189
    /// Records thread-level `feedback` and reports it to telemetry.
    ///
    /// When the thread has an assistant message, this delegates to
    /// `report_message_feedback` for the most recent one; otherwise it
    /// reports a thread-only rating (without per-message fields).
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach the rating to — report the
            // thread-level rating directly.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Fall back to null thread data on serialization failure.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2235
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are gathered concurrently; dirty buffers that have
    /// a backing file contribute their paths to `unsaved_buffer_paths`.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            // Best-effort: if the app context is gone, the snapshot simply
            // has no unsaved-buffer info.
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2273
    /// Builds a background task that captures a [`WorktreeSnapshot`] for one
    /// worktree: its absolute path plus, when a containing git repository is
    /// found, that repository's [`GitState`].
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If app state can no longer be read, return an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find a repository whose path space contains this worktree's
            // root, if any, then kick off a job to read its git state.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Remote URL, HEAD SHA, and diff are only read for
                            // local repositories; otherwise just the branch.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Flatten Option<Result<Task<..>>> into Option<GitState>,
            // dropping errors at every stage (entity gone, job failed).
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2351
2352 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2353 let mut markdown = Vec::new();
2354
2355 if let Some(summary) = self.summary() {
2356 writeln!(markdown, "# {summary}\n")?;
2357 };
2358
2359 for message in self.messages() {
2360 writeln!(
2361 markdown,
2362 "## {role}\n",
2363 role = match message.role {
2364 Role::User => "User",
2365 Role::Assistant => "Agent",
2366 Role::System => "System",
2367 }
2368 )?;
2369
2370 if !message.loaded_context.text.is_empty() {
2371 writeln!(markdown, "{}", message.loaded_context.text)?;
2372 }
2373
2374 if !message.loaded_context.images.is_empty() {
2375 writeln!(
2376 markdown,
2377 "\n{} images attached as context.\n",
2378 message.loaded_context.images.len()
2379 )?;
2380 }
2381
2382 for segment in &message.segments {
2383 match segment {
2384 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2385 MessageSegment::Thinking { text, .. } => {
2386 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2387 }
2388 MessageSegment::RedactedThinking(_) => {}
2389 }
2390 }
2391
2392 for tool_use in self.tool_uses_for_message(message.id, cx) {
2393 writeln!(
2394 markdown,
2395 "**Use Tool: {} ({})**",
2396 tool_use.name, tool_use.id
2397 )?;
2398 writeln!(markdown, "```json")?;
2399 writeln!(
2400 markdown,
2401 "{}",
2402 serde_json::to_string_pretty(&tool_use.input)?
2403 )?;
2404 writeln!(markdown, "```")?;
2405 }
2406
2407 for tool_result in self.tool_results_for_message(message.id) {
2408 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2409 if tool_result.is_error {
2410 write!(markdown, " (Error)")?;
2411 }
2412
2413 writeln!(markdown, "**\n")?;
2414 writeln!(markdown, "{}", tool_result.content)?;
2415 }
2416 }
2417
2418 Ok(String::from_utf8_lossy(&markdown).to_string())
2419 }
2420
2421 pub fn keep_edits_in_range(
2422 &mut self,
2423 buffer: Entity<language::Buffer>,
2424 buffer_range: Range<language::Anchor>,
2425 cx: &mut Context<Self>,
2426 ) {
2427 self.action_log.update(cx, |action_log, cx| {
2428 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2429 });
2430 }
2431
2432 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2433 self.action_log
2434 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2435 }
2436
2437 pub fn reject_edits_in_ranges(
2438 &mut self,
2439 buffer: Entity<language::Buffer>,
2440 buffer_ranges: Vec<Range<language::Anchor>>,
2441 cx: &mut Context<Self>,
2442 ) -> Task<Result<()>> {
2443 self.action_log.update(cx, |action_log, cx| {
2444 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2445 })
2446 }
2447
    /// The [`ActionLog`] used by this thread to track agent edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2451
    /// The [`Project`] this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2455
2456 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2457 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2458 return;
2459 }
2460
2461 let now = Instant::now();
2462 if let Some(last) = self.last_auto_capture_at {
2463 if now.duration_since(last).as_secs() < 10 {
2464 return;
2465 }
2466 }
2467
2468 self.last_auto_capture_at = Some(now);
2469
2470 let thread_id = self.id().clone();
2471 let github_login = self
2472 .project
2473 .read(cx)
2474 .user_store()
2475 .read(cx)
2476 .current_user()
2477 .map(|user| user.github_login.clone());
2478 let client = self.project.read(cx).client().clone();
2479 let serialize_task = self.serialize(cx);
2480
2481 cx.background_executor()
2482 .spawn(async move {
2483 if let Ok(serialized_thread) = serialize_task.await {
2484 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2485 telemetry::event!(
2486 "Agent Thread Auto-Captured",
2487 thread_id = thread_id.to_string(),
2488 thread_data = thread_data,
2489 auto_capture_reason = "tracked_user",
2490 github_login = github_login
2491 );
2492
2493 client.telemetry().flush_events().await;
2494 }
2495 }
2496 })
2497 .detach();
2498 }
2499
    /// Returns the cumulative token usage recorded on this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2503
2504 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2505 let Some(model) = self.configured_model.as_ref() else {
2506 return TotalTokenUsage::default();
2507 };
2508
2509 let max = model.model.max_token_count();
2510
2511 let index = self
2512 .messages
2513 .iter()
2514 .position(|msg| msg.id == message_id)
2515 .unwrap_or(0);
2516
2517 if index == 0 {
2518 return TotalTokenUsage { total: 0, max };
2519 }
2520
2521 let token_usage = &self
2522 .request_token_usage
2523 .get(index - 1)
2524 .cloned()
2525 .unwrap_or_default();
2526
2527 TotalTokenUsage {
2528 total: token_usage.total_tokens() as usize,
2529 max,
2530 }
2531 }
2532
2533 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2534 let model = self.configured_model.as_ref()?;
2535
2536 let max = model.model.max_token_count();
2537
2538 if let Some(exceeded_error) = &self.exceeded_window_error {
2539 if model.model.id() == exceeded_error.model_id {
2540 return Some(TotalTokenUsage {
2541 total: exceeded_error.token_count,
2542 max,
2543 });
2544 }
2545 }
2546
2547 let total = self
2548 .token_usage_at_last_message()
2549 .unwrap_or_default()
2550 .total_tokens() as usize;
2551
2552 Some(TotalTokenUsage { total, max })
2553 }
2554
2555 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2556 self.request_token_usage
2557 .get(self.messages.len().saturating_sub(1))
2558 .or_else(|| self.request_token_usage.last())
2559 .cloned()
2560 }
2561
2562 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2563 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2564 self.request_token_usage
2565 .resize(self.messages.len(), placeholder);
2566
2567 if let Some(last) = self.request_token_usage.last_mut() {
2568 *last = token_usage;
2569 }
2570 }
2571
2572 pub fn deny_tool_use(
2573 &mut self,
2574 tool_use_id: LanguageModelToolUseId,
2575 tool_name: Arc<str>,
2576 window: Option<AnyWindowHandle>,
2577 cx: &mut Context<Self>,
2578 ) {
2579 let err = Err(anyhow::anyhow!(
2580 "Permission to run tool action denied by user"
2581 ));
2582
2583 self.tool_use.insert_tool_output(
2584 tool_use_id.clone(),
2585 tool_name,
2586 err,
2587 self.configured_model.as_ref(),
2588 );
2589 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2590 }
2591}
2592
/// User-visible failure modes for a thread; display text comes from the
/// `#[error(...)]` attributes via `thiserror`.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before more requests can be made.
    #[error("Payment required")]
    PaymentRequired,
    /// The maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short header and a longer message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2607
/// Events emitted by a [`Thread`] while streaming completions and running
/// tools (`Thread` implements `EventEmitter<ThreadEvent>` below).
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// A chunk of assistant text for the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant "thinking" text for the given message.
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        // Raw JSON text that failed validation, kept for display/debugging.
        invalid_input_json: Arc<str>,
    },
    /// The completion ended, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    CancelEditing,
    CompletionCanceled,
}
2650
// `Thread` broadcasts `ThreadEvent`s to observers through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2652
/// A completion request that is currently in flight, identified by `id`.
struct PendingCompletion {
    id: usize,
    queue_state: QueueState,
    // Keeps the streaming task alive for as long as this entry exists.
    _task: Task<()>,
}
2658
// Unit tests for `Thread`: message/context assembly, deduplication of
// attached contexts, stale-buffer notifications, and resolution of
// per-model temperature settings.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
    use assistant_settings::{AssistantSettings, LanguageModelParameters};
    use assistant_tool::ToolRegistry;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use language_model::fake_provider::FakeLanguageModel;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // A file attached as context must be embedded both in the stored message
    // and in the completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    // Contexts already attached to an earlier message must not be re-sent
    // with later messages; removing contexts makes them eligible again.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }

    // Messages without attached context carry an empty context string and are
    // forwarded to the model as-is.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    // Editing a buffer that was previously attached as context must append a
    // "files changed since last read" notification to the next request.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    // `model_parameters` entries apply when provider and/or model match the
    // active model; a mismatched provider must not apply.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }

    // Registers every settings domain the thread machinery touches.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    // Builds the full fixture: workspace, thread store, an empty thread, a
    // context store, and a fake language model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }

    // Opens `path` in `project` and registers the buffer with the context
    // store, returning the buffer for further edits.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
}