1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, TryFutureExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
/// Globally unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
55
impl ThreadId {
    /// Creates a new random thread ID (UUID v4).
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    // Used when rehydrating a thread ID from a stored string.
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
73
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a new random prompt ID (UUID v4).
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
91
92#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
93pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message text that the crease covers.
    pub range: Range<usize>,
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
109
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context that was attached to this message when it was sent.
    pub loaded_context: LoadedContext,
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Whether this segment carries any meaningful content to render.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking never is. (Previously the emptiness check was
    /// inverted, so only empty segments counted as displayable.)
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Snapshot of the project's worktrees at a point in time, kept with a thread
/// for later comparison/telemetry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes when the snapshot was taken.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when no git state was captured for this worktree.
    pub git_state: Option<GitState>,
}

/// Git metadata captured as part of a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// A git checkpoint tied to the message that triggered it, allowing the
/// project to be restored to its state at that message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
222
/// User feedback (thumbs up/down) on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

/// State of the most recent checkpoint-restore attempt.
pub enum LastRestoreCheckpoint {
    Pending {
        message_id: MessageId,
    },
    Error {
        message_id: MessageId,
        error: String,
    },
}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
/// Progress of detailed-summary generation for a thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    Generating {
        // NOTE(review): presumably the latest message covered by the in-flight
        // generation — confirm against the generation code.
        message_id: MessageId,
    },
    Generated {
        text: SharedString,
        // NOTE(review): presumably the latest message the text covers — confirm.
        message_id: MessageId,
    },
}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token total for a thread, relative to the model's context-window
/// maximum.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens used so far.
    pub total: usize,
    /// Context window size; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, nearing the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via an env
        // var for testing. Fall back to the default when the variable is
        // missing or malformed instead of panicking on user input.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a new usage with `tokens` added to the total; `max` is unchanged.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
315
/// Lifecycle state of a pending completion request.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // `None` until a summary has been generated for the thread.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    // Kept sorted by `Message::id` (IDs are assigned monotonically on
    // insertion); `message()` binary-searches on it.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken when the last user message was inserted; kept or
    // discarded by `finalize_pending_checkpoint` once generation completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recent streamed chunk; used for staleness checks.
    last_received_chunk_at: Option<Instant>,
    // Observer invoked with each outgoing request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    configured_profile: Option<AgentProfile>,
}
364
/// Recorded when a request fails because the model's context window was exceeded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
372
373impl Thread {
    /// Creates a new, empty thread, picking up the default model and profile
    /// from the global settings.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let assistant_settings = AssistantSettings::get_global(cx);
        let profile_id = &assistant_settings.default_profile;
        let configured_profile = assistant_settings.profiles.get(profile_id).cloned();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project state up front so later consumers can
                // compare against the state at thread creation.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            configured_profile,
        }
    }
431
    /// Reconstructs a thread from its serialized form.
    ///
    /// Restores messages, tool-use state, and summary channels; the model and
    /// profile fall back to the global defaults when the serialized ones can
    /// no longer be resolved.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        let configured_profile = serialized.profile.and_then(|profile| {
            AssistantSettings::get_global(cx)
                .profiles
                .get(&profile)
                .cloned()
        });

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Live context handles aren't persisted, so
                            // deserialized creases carry none.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            configured_profile,
        }
    }
559
    /// Registers a callback invoked with each outgoing request and the events
    /// streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
            + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the last-updated timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, falling back to (and caching) the global
    /// default if none has been configured yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn configured_profile(&self) -> Option<AgentProfile> {
        self.configured_profile.clone()
    }

    pub fn set_configured_profile(
        &mut self,
        profile: Option<AgentProfile>,
        cx: &mut Context<Self>,
    ) {
        self.configured_profile = profile;
        cx.notify();
    }

    /// Summary shown before one has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }

    /// Replaces the summary, emitting [`ThreadEvent::SummaryChanged`] when it
    /// actually changes. No-op until an initial summary exists; an empty input
    /// falls back to [`Self::DEFAULT_SUMMARY`].
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let Some(current_summary) = &self.summary else {
            // Don't allow setting summary until generated
            return;
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = Self::DEFAULT_SUMMARY;
        }

        if current_summary != &new_summary {
            self.summary = Some(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
656
657 pub fn message(&self, id: MessageId) -> Option<&Message> {
658 let index = self
659 .messages
660 .binary_search_by(|message| message.id.cmp(&id))
661 .ok()?;
662
663 self.messages.get(index)
664 }
665
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// True while a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events looks stale, i.e.
    /// no chunk has arrived within the threshold (milliseconds).
    ///
    /// Returns `None` when no chunk has ever been received.
    /// NOTE(review): this does not itself consult `is_generating()` — the last
    /// chunk timestamp stays set after a completion finishes, so callers
    /// appear expected to gate on `is_generating()` first; confirm intent.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    // Records that a streamed chunk just arrived (for staleness tracking).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
718
    /// Restores the project to `checkpoint`'s git state, truncating the thread
    /// back to the checkpoint's message on success.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around so the UI can surface the failure.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
754
    /// Once generation has finished, decides whether the pending checkpoint is
    /// worth keeping: it is stored only if the repository actually changed
    /// since the checkpoint was taken, or if the comparison itself failed.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If we can't even take a final checkpoint, err on the side of
            // keeping the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
793
    // Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
804
805 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
806 let Some(message_ix) = self
807 .messages
808 .iter()
809 .rposition(|message| message.id == message_id)
810 else {
811 return;
812 };
813 for deleted_message in self.messages.drain(message_ix..) {
814 self.checkpoints_by_message.remove(&deleted_message.id);
815 }
816 cx.notify();
817 }
818
    /// Context items attached to the message with the given ID (empty iterator
    /// when the message doesn't exist).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
826
827 pub fn is_turn_end(&self, ix: usize) -> bool {
828 if self.messages.is_empty() {
829 return false;
830 }
831
832 if !self.is_generating() && ix == self.messages.len() - 1 {
833 return true;
834 }
835
836 let Some(message) = self.messages.get(ix) else {
837 return false;
838 };
839
840 if message.role != Role::Assistant {
841 return false;
842 }
843
844 self.messages
845 .get(ix + 1)
846 .and_then(|message| {
847 self.message(message.id)
848 .map(|next_message| next_message.role == Role::User)
849 })
850 .unwrap_or(false)
851 }
852
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
870
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The raw content of a completed tool use, if a result exists for `id`.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }

    /// Return tools that are both enabled and supported by the model
    pub fn available_tools(
        &self,
        cx: &App,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<LanguageModelRequestTool> {
        if model.supports_tools() {
            self.tools()
                .read(cx)
                .enabled_tools(cx)
                .into_iter()
                .filter_map(|tool| {
                    // Skip tools that cannot be supported
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name: tool.name(),
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect()
        } else {
            Vec::default()
        }
    }
919
    /// Appends a user message, marking any referenced buffers as read in the
    /// action log and (optionally) recording a pending git checkpoint tied to
    /// the new message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }

    /// Appends a message with the next sequential ID and emits
    /// [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
990
    /// Rewrites an existing message in place; returns `false` when `id` is
    /// unknown. Context is only replaced when `loaded_context` is `Some`.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        loaded_context: Option<LoadedContext>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }

    /// Removes a message; returns `false` when `id` is unknown.
    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
            return false;
        };
        self.messages.remove(index);
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageDeleted(id));
        true
    }
1021
1022 /// Returns the representation of this [`Thread`] in a textual form.
1023 ///
1024 /// This is the representation we use when attaching a thread as context to another thread.
1025 pub fn text(&self) -> String {
1026 let mut text = String::new();
1027
1028 for message in &self.messages {
1029 text.push_str(match message.role {
1030 language_model::Role::User => "User:",
1031 language_model::Role::Assistant => "Agent:",
1032 language_model::Role::System => "System:",
1033 });
1034 text.push('\n');
1035
1036 for segment in &message.segments {
1037 match segment {
1038 MessageSegment::Text(content) => text.push_str(content),
1039 MessageSegment::Thinking { text: content, .. } => {
1040 text.push_str(&format!("<think>{}</think>", content))
1041 }
1042 MessageSegment::RedactedThinking(_) => {}
1043 }
1044 }
1045 text.push('\n');
1046 }
1047
1048 text
1049 }
1050
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot captured at thread creation before reading
            // the rest of the state back off the entity.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                profile: this
                    .configured_profile
                    .as_ref()
                    .map(|profile| AgentProfileId(profile.name.clone().into())),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1137
    /// Returns how many more turns the agent may take before `send_to_model`
    /// starts refusing to run (see [`Self::send_to_model`]).
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1141
    /// Sets the turn budget consumed by [`Self::send_to_model`].
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1145
1146 pub fn send_to_model(
1147 &mut self,
1148 model: Arc<dyn LanguageModel>,
1149 window: Option<AnyWindowHandle>,
1150 cx: &mut Context<Self>,
1151 ) {
1152 if self.remaining_turns == 0 {
1153 return;
1154 }
1155
1156 self.remaining_turns -= 1;
1157
1158 let request = self.to_completion_request(model.clone(), cx);
1159
1160 self.stream_completion(request, model, window, cx);
1161 }
1162
1163 pub fn used_tools_since_last_user_message(&self) -> bool {
1164 for message in self.messages.iter().rev() {
1165 if self.tool_use.message_has_tool_results(message.id) {
1166 return true;
1167 } else if message.role == Role::User {
1168 return false;
1169 }
1170 }
1171
1172 false
1173 }
1174
    /// Builds the completion request for the next model turn.
    ///
    /// The request contains, in order: a system prompt generated from the
    /// shared project context, every thread message (its loaded context,
    /// visible segments, tool uses, and tool results), and a trailing user
    /// message listing files that changed since the model last read them.
    /// The final message is marked cacheable for prompt-caching providers.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        // The system prompt names the tools the model is allowed to call.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from the project context. Failures are
        // surfaced to the user but do not abort building the request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Convert each thread message, carrying its attached context, visible
        // segments, tool uses, and (as a follow-up message) tool results.
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            self.tool_use
                .attach_tool_uses(message.id, &mut request_message);

            request.messages.push(request_message);

            if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(last) = request.messages.last_mut() {
            last.cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Only request max/burn mode from models that support it.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1292
1293 fn to_summarize_request(
1294 &self,
1295 model: &Arc<dyn LanguageModel>,
1296 added_user_message: String,
1297 cx: &App,
1298 ) -> LanguageModelRequest {
1299 let mut request = LanguageModelRequest {
1300 thread_id: None,
1301 prompt_id: None,
1302 mode: None,
1303 messages: vec![],
1304 tools: Vec::new(),
1305 stop: Vec::new(),
1306 temperature: AssistantSettings::temperature_for_model(model, cx),
1307 };
1308
1309 for message in &self.messages {
1310 let mut request_message = LanguageModelRequestMessage {
1311 role: message.role,
1312 content: Vec::new(),
1313 cache: false,
1314 };
1315
1316 for segment in &message.segments {
1317 match segment {
1318 MessageSegment::Text(text) => request_message
1319 .content
1320 .push(MessageContent::Text(text.clone())),
1321 MessageSegment::Thinking { .. } => {}
1322 MessageSegment::RedactedThinking(_) => {}
1323 }
1324 }
1325
1326 if request_message.content.is_empty() {
1327 continue;
1328 }
1329
1330 request.messages.push(request_message);
1331 }
1332
1333 request.messages.push(LanguageModelRequestMessage {
1334 role: Role::User,
1335 content: vec![MessageContent::Text(added_user_message)],
1336 cache: false,
1337 });
1338
1339 request
1340 }
1341
1342 fn attached_tracked_files_state(
1343 &self,
1344 messages: &mut Vec<LanguageModelRequestMessage>,
1345 cx: &App,
1346 ) {
1347 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1348
1349 let mut stale_message = String::new();
1350
1351 let action_log = self.action_log.read(cx);
1352
1353 for stale_file in action_log.stale_buffers(cx) {
1354 let Some(file) = stale_file.read(cx).file() else {
1355 continue;
1356 };
1357
1358 if stale_message.is_empty() {
1359 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1360 }
1361
1362 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1363 }
1364
1365 let mut content = Vec::with_capacity(2);
1366
1367 if !stale_message.is_empty() {
1368 content.push(stale_message.into());
1369 }
1370
1371 if !content.is_empty() {
1372 let context_message = LanguageModelRequestMessage {
1373 role: Role::User,
1374 content,
1375 cache: false,
1376 };
1377
1378 messages.push(context_message);
1379 }
1380 }
1381
1382 pub fn stream_completion(
1383 &mut self,
1384 request: LanguageModelRequest,
1385 model: Arc<dyn LanguageModel>,
1386 window: Option<AnyWindowHandle>,
1387 cx: &mut Context<Self>,
1388 ) {
1389 self.tool_use_limit_reached = false;
1390
1391 let pending_completion_id = post_inc(&mut self.completion_count);
1392 let mut request_callback_parameters = if self.request_callback.is_some() {
1393 Some((request.clone(), Vec::new()))
1394 } else {
1395 None
1396 };
1397 let prompt_id = self.last_prompt_id.clone();
1398 let tool_use_metadata = ToolUseMetadata {
1399 model: model.clone(),
1400 thread_id: self.id.clone(),
1401 prompt_id: prompt_id.clone(),
1402 };
1403
1404 self.last_received_chunk_at = Some(Instant::now());
1405
1406 let task = cx.spawn(async move |thread, cx| {
1407 let stream_completion_future = model.stream_completion(request, &cx);
1408 let initial_token_usage =
1409 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1410 let stream_completion = async {
1411 let mut events = stream_completion_future.await?;
1412
1413 let mut stop_reason = StopReason::EndTurn;
1414 let mut current_token_usage = TokenUsage::default();
1415
1416 thread
1417 .update(cx, |_thread, cx| {
1418 cx.emit(ThreadEvent::NewRequest);
1419 })
1420 .ok();
1421
1422 let mut request_assistant_message_id = None;
1423
1424 while let Some(event) = events.next().await {
1425 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1426 response_events
1427 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1428 }
1429
1430 thread.update(cx, |thread, cx| {
1431 let event = match event {
1432 Ok(event) => event,
1433 Err(LanguageModelCompletionError::BadInputJson {
1434 id,
1435 tool_name,
1436 raw_input: invalid_input_json,
1437 json_parse_error,
1438 }) => {
1439 thread.receive_invalid_tool_json(
1440 id,
1441 tool_name,
1442 invalid_input_json,
1443 json_parse_error,
1444 window,
1445 cx,
1446 );
1447 return Ok(());
1448 }
1449 Err(LanguageModelCompletionError::Other(error)) => {
1450 return Err(error);
1451 }
1452 };
1453
1454 match event {
1455 LanguageModelCompletionEvent::StartMessage { .. } => {
1456 request_assistant_message_id =
1457 Some(thread.insert_assistant_message(
1458 vec![MessageSegment::Text(String::new())],
1459 cx,
1460 ));
1461 }
1462 LanguageModelCompletionEvent::Stop(reason) => {
1463 stop_reason = reason;
1464 }
1465 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1466 thread.update_token_usage_at_last_message(token_usage);
1467 thread.cumulative_token_usage = thread.cumulative_token_usage
1468 + token_usage
1469 - current_token_usage;
1470 current_token_usage = token_usage;
1471 }
1472 LanguageModelCompletionEvent::Text(chunk) => {
1473 thread.received_chunk();
1474
1475 cx.emit(ThreadEvent::ReceivedTextChunk);
1476 if let Some(last_message) = thread.messages.last_mut() {
1477 if last_message.role == Role::Assistant
1478 && !thread.tool_use.has_tool_results(last_message.id)
1479 {
1480 last_message.push_text(&chunk);
1481 cx.emit(ThreadEvent::StreamedAssistantText(
1482 last_message.id,
1483 chunk,
1484 ));
1485 } else {
1486 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1487 // of a new Assistant response.
1488 //
1489 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1490 // will result in duplicating the text of the chunk in the rendered Markdown.
1491 request_assistant_message_id =
1492 Some(thread.insert_assistant_message(
1493 vec![MessageSegment::Text(chunk.to_string())],
1494 cx,
1495 ));
1496 };
1497 }
1498 }
1499 LanguageModelCompletionEvent::Thinking {
1500 text: chunk,
1501 signature,
1502 } => {
1503 thread.received_chunk();
1504
1505 if let Some(last_message) = thread.messages.last_mut() {
1506 if last_message.role == Role::Assistant
1507 && !thread.tool_use.has_tool_results(last_message.id)
1508 {
1509 last_message.push_thinking(&chunk, signature);
1510 cx.emit(ThreadEvent::StreamedAssistantThinking(
1511 last_message.id,
1512 chunk,
1513 ));
1514 } else {
1515 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1516 // of a new Assistant response.
1517 //
1518 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1519 // will result in duplicating the text of the chunk in the rendered Markdown.
1520 request_assistant_message_id =
1521 Some(thread.insert_assistant_message(
1522 vec![MessageSegment::Thinking {
1523 text: chunk.to_string(),
1524 signature,
1525 }],
1526 cx,
1527 ));
1528 };
1529 }
1530 }
1531 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1532 let last_assistant_message_id = request_assistant_message_id
1533 .unwrap_or_else(|| {
1534 let new_assistant_message_id =
1535 thread.insert_assistant_message(vec![], cx);
1536 request_assistant_message_id =
1537 Some(new_assistant_message_id);
1538 new_assistant_message_id
1539 });
1540
1541 let tool_use_id = tool_use.id.clone();
1542 let streamed_input = if tool_use.is_input_complete {
1543 None
1544 } else {
1545 Some((&tool_use.input).clone())
1546 };
1547
1548 let ui_text = thread.tool_use.request_tool_use(
1549 last_assistant_message_id,
1550 tool_use,
1551 tool_use_metadata.clone(),
1552 cx,
1553 );
1554
1555 if let Some(input) = streamed_input {
1556 cx.emit(ThreadEvent::StreamedToolUse {
1557 tool_use_id,
1558 ui_text,
1559 input,
1560 });
1561 }
1562 }
1563 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1564 if let Some(completion) = thread
1565 .pending_completions
1566 .iter_mut()
1567 .find(|completion| completion.id == pending_completion_id)
1568 {
1569 match status_update {
1570 CompletionRequestStatus::Queued {
1571 position,
1572 } => {
1573 completion.queue_state = QueueState::Queued { position };
1574 }
1575 CompletionRequestStatus::Started => {
1576 completion.queue_state = QueueState::Started;
1577 }
1578 CompletionRequestStatus::Failed {
1579 code, message, request_id
1580 } => {
1581 return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
1582 }
1583 CompletionRequestStatus::UsageUpdated {
1584 amount, limit
1585 } => {
1586 let usage = RequestUsage { limit, amount: amount as i32 };
1587
1588 thread.last_usage = Some(usage);
1589 }
1590 CompletionRequestStatus::ToolUseLimitReached => {
1591 thread.tool_use_limit_reached = true;
1592 }
1593 }
1594 }
1595 }
1596 }
1597
1598 thread.touch_updated_at();
1599 cx.emit(ThreadEvent::StreamedCompletion);
1600 cx.notify();
1601
1602 thread.auto_capture_telemetry(cx);
1603 Ok(())
1604 })??;
1605
1606 smol::future::yield_now().await;
1607 }
1608
1609 thread.update(cx, |thread, cx| {
1610 thread.last_received_chunk_at = None;
1611 thread
1612 .pending_completions
1613 .retain(|completion| completion.id != pending_completion_id);
1614
1615 // If there is a response without tool use, summarize the message. Otherwise,
1616 // allow two tool uses before summarizing.
1617 if thread.summary.is_none()
1618 && thread.messages.len() >= 2
1619 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1620 {
1621 thread.summarize(cx);
1622 }
1623 })?;
1624
1625 anyhow::Ok(stop_reason)
1626 };
1627
1628 let result = stream_completion.await;
1629
1630 thread
1631 .update(cx, |thread, cx| {
1632 thread.finalize_pending_checkpoint(cx);
1633 match result.as_ref() {
1634 Ok(stop_reason) => match stop_reason {
1635 StopReason::ToolUse => {
1636 let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1637 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1638 }
1639 StopReason::EndTurn | StopReason::MaxTokens => {
1640 thread.project.update(cx, |project, cx| {
1641 project.set_agent_location(None, cx);
1642 });
1643 }
1644 },
1645 Err(error) => {
1646 thread.project.update(cx, |project, cx| {
1647 project.set_agent_location(None, cx);
1648 });
1649
1650 if error.is::<PaymentRequiredError>() {
1651 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1652 } else if error.is::<MaxMonthlySpendReachedError>() {
1653 cx.emit(ThreadEvent::ShowError(
1654 ThreadError::MaxMonthlySpendReached,
1655 ));
1656 } else if let Some(error) =
1657 error.downcast_ref::<ModelRequestLimitReachedError>()
1658 {
1659 cx.emit(ThreadEvent::ShowError(
1660 ThreadError::ModelRequestLimitReached { plan: error.plan },
1661 ));
1662 } else if let Some(known_error) =
1663 error.downcast_ref::<LanguageModelKnownError>()
1664 {
1665 match known_error {
1666 LanguageModelKnownError::ContextWindowLimitExceeded {
1667 tokens,
1668 } => {
1669 thread.exceeded_window_error = Some(ExceededWindowError {
1670 model_id: model.id(),
1671 token_count: *tokens,
1672 });
1673 cx.notify();
1674 }
1675 }
1676 } else {
1677 let error_message = error
1678 .chain()
1679 .map(|err| err.to_string())
1680 .collect::<Vec<_>>()
1681 .join("\n");
1682 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1683 header: "Error interacting with language model".into(),
1684 message: SharedString::from(error_message.clone()),
1685 }));
1686 }
1687
1688 thread.cancel_last_completion(window, cx);
1689 }
1690 }
1691 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1692
1693 if let Some((request_callback, (request, response_events))) = thread
1694 .request_callback
1695 .as_mut()
1696 .zip(request_callback_parameters.as_ref())
1697 {
1698 request_callback(request, response_events);
1699 }
1700
1701 thread.auto_capture_telemetry(cx);
1702
1703 if let Ok(initial_usage) = initial_token_usage {
1704 let usage = thread.cumulative_token_usage - initial_usage;
1705
1706 telemetry::event!(
1707 "Assistant Thread Completion",
1708 thread_id = thread.id().to_string(),
1709 prompt_id = prompt_id,
1710 model = model.telemetry_id(),
1711 model_provider = model.provider_id().to_string(),
1712 input_tokens = usage.input_tokens,
1713 output_tokens = usage.output_tokens,
1714 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1715 cache_read_input_tokens = usage.cache_read_input_tokens,
1716 );
1717 }
1718 })
1719 .ok();
1720 });
1721
1722 self.pending_completions.push(PendingCompletion {
1723 id: pending_completion_id,
1724 queue_state: QueueState::Sending,
1725 _task: task,
1726 });
1727 }
1728
1729 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1730 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1731 return;
1732 };
1733
1734 if !model.provider.is_authenticated(cx) {
1735 return;
1736 }
1737
1738 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1739 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1740 If the conversation is about a specific subject, include it in the title. \
1741 Be descriptive. DO NOT speak in the first person.";
1742
1743 let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1744
1745 self.pending_summary = cx.spawn(async move |this, cx| {
1746 async move {
1747 let mut messages = model.model.stream_completion(request, &cx).await?;
1748
1749 let mut new_summary = String::new();
1750 while let Some(event) = messages.next().await {
1751 let event = event?;
1752 let text = match event {
1753 LanguageModelCompletionEvent::Text(text) => text,
1754 LanguageModelCompletionEvent::StatusUpdate(
1755 CompletionRequestStatus::UsageUpdated { amount, limit },
1756 ) => {
1757 this.update(cx, |thread, _cx| {
1758 thread.last_usage = Some(RequestUsage {
1759 limit,
1760 amount: amount as i32,
1761 });
1762 })?;
1763 continue;
1764 }
1765 _ => continue,
1766 };
1767
1768 let mut lines = text.lines();
1769 new_summary.extend(lines.next());
1770
1771 // Stop if the LLM generated multiple lines.
1772 if lines.next().is_some() {
1773 break;
1774 }
1775 }
1776
1777 this.update(cx, |this, cx| {
1778 if !new_summary.is_empty() {
1779 this.summary = Some(new_summary.into());
1780 }
1781
1782 cx.emit(ThreadEvent::SummaryGenerated);
1783 })?;
1784
1785 anyhow::Ok(())
1786 }
1787 .log_err()
1788 .await
1789 });
1790 }
1791
    /// Kicks off background generation of a detailed Markdown summary unless
    /// one is already generated (or being generated) for the current last
    /// message.
    ///
    /// Progress is published through `detailed_summary_tx`; on success the
    /// thread is saved via `thread_store` so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // An empty thread has nothing to summarize.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails to start, reset the state so a later call
            // can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1881
1882 pub async fn wait_for_detailed_summary_or_text(
1883 this: &Entity<Self>,
1884 cx: &mut AsyncApp,
1885 ) -> Option<SharedString> {
1886 let mut detailed_summary_rx = this
1887 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
1888 .ok()?;
1889 loop {
1890 match detailed_summary_rx.recv().await? {
1891 DetailedSummaryState::Generating { .. } => {}
1892 DetailedSummaryState::NotGenerated => {
1893 return this.read_with(cx, |this, _cx| this.text().into()).ok();
1894 }
1895 DetailedSummaryState::Generated { text, .. } => return Some(text),
1896 }
1897 }
1898 }
1899
1900 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1901 self.detailed_summary_rx
1902 .borrow()
1903 .text()
1904 .unwrap_or_else(|| self.text().into())
1905 }
1906
    /// Whether a detailed-summary generation task is currently in flight.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1913
1914 pub fn use_pending_tools(
1915 &mut self,
1916 window: Option<AnyWindowHandle>,
1917 cx: &mut Context<Self>,
1918 model: Arc<dyn LanguageModel>,
1919 ) -> Vec<PendingToolUse> {
1920 self.auto_capture_telemetry(cx);
1921 let request = self.to_completion_request(model.clone(), cx);
1922 let messages = Arc::new(request.messages);
1923 let pending_tool_uses = self
1924 .tool_use
1925 .pending_tool_uses()
1926 .into_iter()
1927 .filter(|tool_use| tool_use.status.is_idle())
1928 .cloned()
1929 .collect::<Vec<_>>();
1930
1931 for tool_use in pending_tool_uses.iter() {
1932 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1933 if tool.needs_confirmation(&tool_use.input, cx)
1934 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1935 {
1936 self.tool_use.confirm_tool_use(
1937 tool_use.id.clone(),
1938 tool_use.ui_text.clone(),
1939 tool_use.input.clone(),
1940 messages.clone(),
1941 tool,
1942 );
1943 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1944 } else {
1945 self.run_tool(
1946 tool_use.id.clone(),
1947 tool_use.ui_text.clone(),
1948 tool_use.input.clone(),
1949 &messages,
1950 tool,
1951 model.clone(),
1952 window,
1953 cx,
1954 );
1955 }
1956 } else {
1957 self.handle_hallucinated_tool_use(
1958 tool_use.id.clone(),
1959 tool_use.name.clone(),
1960 window,
1961 cx,
1962 );
1963 }
1964 }
1965
1966 pending_tool_uses
1967 }
1968
1969 pub fn handle_hallucinated_tool_use(
1970 &mut self,
1971 tool_use_id: LanguageModelToolUseId,
1972 hallucinated_tool_name: Arc<str>,
1973 window: Option<AnyWindowHandle>,
1974 cx: &mut Context<Thread>,
1975 ) {
1976 let available_tools = self.tools.read(cx).enabled_tools(cx);
1977
1978 let tool_list = available_tools
1979 .iter()
1980 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
1981 .collect::<Vec<_>>()
1982 .join("\n");
1983
1984 let error_message = format!(
1985 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
1986 hallucinated_tool_name, tool_list
1987 );
1988
1989 let pending_tool_use = self.tool_use.insert_tool_output(
1990 tool_use_id.clone(),
1991 hallucinated_tool_name,
1992 Err(anyhow!("Missing tool call: {error_message}")),
1993 self.configured_model.as_ref(),
1994 );
1995
1996 cx.emit(ThreadEvent::MissingToolUse {
1997 tool_use_id: tool_use_id.clone(),
1998 ui_text: error_message.into(),
1999 });
2000
2001 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2002 }
2003
2004 pub fn receive_invalid_tool_json(
2005 &mut self,
2006 tool_use_id: LanguageModelToolUseId,
2007 tool_name: Arc<str>,
2008 invalid_json: Arc<str>,
2009 error: String,
2010 window: Option<AnyWindowHandle>,
2011 cx: &mut Context<Thread>,
2012 ) {
2013 log::error!("The model returned invalid input JSON: {invalid_json}");
2014
2015 let pending_tool_use = self.tool_use.insert_tool_output(
2016 tool_use_id.clone(),
2017 tool_name,
2018 Err(anyhow!("Error parsing input JSON: {error}")),
2019 self.configured_model.as_ref(),
2020 );
2021 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2022 pending_tool_use.ui_text.clone()
2023 } else {
2024 log::error!(
2025 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2026 );
2027 format!("Unknown tool {}", tool_use_id).into()
2028 };
2029
2030 cx.emit(ThreadEvent::InvalidToolInput {
2031 tool_use_id: tool_use_id.clone(),
2032 ui_text,
2033 invalid_input_json: invalid_json,
2034 });
2035
2036 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2037 }
2038
    /// Runs `tool` with `input` and marks the tool use as running.
    ///
    /// The spawned task records the tool's output on this thread when it
    /// completes (see [`Self::spawn_tool_use`]).
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        messages: &[LanguageModelRequestMessage],
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task = self.spawn_tool_use(
            tool_use_id.clone(),
            messages,
            input,
            tool,
            model,
            window,
            cx,
        );
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
2062
    /// Starts `tool` (unless its source is disabled) and returns the task
    /// that awaits its output and records the result on this thread via
    /// `insert_tool_output`/`tool_finished`.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        messages: &[LanguageModelRequestMessage],
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // A disabled tool still produces a result — an immediate error —
        // so the model learns the call was rejected.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                messages,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists, before the tool's async
        // work completes, so the UI can render it immediately.
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2113
2114 fn tool_finished(
2115 &mut self,
2116 tool_use_id: LanguageModelToolUseId,
2117 pending_tool_use: Option<PendingToolUse>,
2118 canceled: bool,
2119 window: Option<AnyWindowHandle>,
2120 cx: &mut Context<Self>,
2121 ) {
2122 if self.all_tools_finished() {
2123 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2124 if !canceled {
2125 self.send_to_model(model.clone(), window, cx);
2126 }
2127 self.auto_capture_telemetry(cx);
2128 }
2129 }
2130
2131 cx.emit(ThreadEvent::ToolFinished {
2132 tool_use_id,
2133 pending_tool_use,
2134 });
2135 }
2136
2137 /// Cancels the last pending completion, if there are any pending.
2138 ///
2139 /// Returns whether a completion was canceled.
2140 pub fn cancel_last_completion(
2141 &mut self,
2142 window: Option<AnyWindowHandle>,
2143 cx: &mut Context<Self>,
2144 ) -> bool {
2145 let mut canceled = self.pending_completions.pop().is_some();
2146
2147 for pending_tool_use in self.tool_use.cancel_pending() {
2148 canceled = true;
2149 self.tool_finished(
2150 pending_tool_use.id.clone(),
2151 Some(pending_tool_use),
2152 true,
2153 window,
2154 cx,
2155 );
2156 }
2157
2158 self.finalize_pending_checkpoint(cx);
2159
2160 if canceled {
2161 cx.emit(ThreadEvent::CompletionCanceled);
2162 }
2163
2164 canceled
2165 }
2166
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits
    /// `ThreadEvent::CancelEditing`; no thread state is mutated here.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2174
    /// Returns the thread-level feedback, if the user rated this thread.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2178
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2182
    /// Records `feedback` for `message_id` and sends an "Assistant Thread
    /// Rated" telemetry event containing the serialized thread, a project
    /// snapshot, and the rated message's content.
    ///
    /// Returns immediately with `Ok` when the same feedback is already
    /// recorded for that message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture the snapshot/serialization tasks up front; they complete in
        // the background task below.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2240
    /// Reports thread-level feedback.
    ///
    /// When the thread contains an assistant message, the feedback is
    /// attached to the most recent one via `report_message_feedback`;
    /// otherwise a thread-level "Assistant Thread Rated" telemetry event is
    /// sent directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failure degrades to a null payload rather
                // than dropping the event.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2286
2287 /// Create a snapshot of the current project state including git information and unsaved buffers.
2288 fn project_snapshot(
2289 project: Entity<Project>,
2290 cx: &mut Context<Self>,
2291 ) -> Task<Arc<ProjectSnapshot>> {
2292 let git_store = project.read(cx).git_store().clone();
2293 let worktree_snapshots: Vec<_> = project
2294 .read(cx)
2295 .visible_worktrees(cx)
2296 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2297 .collect();
2298
2299 cx.spawn(async move |_, cx| {
2300 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2301
2302 let mut unsaved_buffers = Vec::new();
2303 cx.update(|app_cx| {
2304 let buffer_store = project.read(app_cx).buffer_store();
2305 for buffer_handle in buffer_store.read(app_cx).buffers() {
2306 let buffer = buffer_handle.read(app_cx);
2307 if buffer.is_dirty() {
2308 if let Some(file) = buffer.file() {
2309 let path = file.path().to_string_lossy().to_string();
2310 unsaved_buffers.push(path);
2311 }
2312 }
2313 }
2314 })
2315 .ok();
2316
2317 Arc::new(ProjectSnapshot {
2318 worktree_snapshots,
2319 unsaved_buffer_paths: unsaved_buffers,
2320 timestamp: Utc::now(),
2321 })
2322 })
2323 }
2324
    /// Captures a `WorktreeSnapshot` (absolute path plus optional git state)
    /// for a single worktree.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is no longer updatable, fall back to an
            // empty snapshot rather than failing the whole capture.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose path mapping covers this worktree's
            // root, then ask it for its git state.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only a local repository state exposes a backend
                            // we can query; otherwise report branch name only.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option/Result/Task layers; any failure along
            // the way simply means "no git state".
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2402
2403 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2404 let mut markdown = Vec::new();
2405
2406 if let Some(summary) = self.summary() {
2407 writeln!(markdown, "# {summary}\n")?;
2408 };
2409
2410 for message in self.messages() {
2411 writeln!(
2412 markdown,
2413 "## {role}\n",
2414 role = match message.role {
2415 Role::User => "User",
2416 Role::Assistant => "Agent",
2417 Role::System => "System",
2418 }
2419 )?;
2420
2421 if !message.loaded_context.text.is_empty() {
2422 writeln!(markdown, "{}", message.loaded_context.text)?;
2423 }
2424
2425 if !message.loaded_context.images.is_empty() {
2426 writeln!(
2427 markdown,
2428 "\n{} images attached as context.\n",
2429 message.loaded_context.images.len()
2430 )?;
2431 }
2432
2433 for segment in &message.segments {
2434 match segment {
2435 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2436 MessageSegment::Thinking { text, .. } => {
2437 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2438 }
2439 MessageSegment::RedactedThinking(_) => {}
2440 }
2441 }
2442
2443 for tool_use in self.tool_uses_for_message(message.id, cx) {
2444 writeln!(
2445 markdown,
2446 "**Use Tool: {} ({})**",
2447 tool_use.name, tool_use.id
2448 )?;
2449 writeln!(markdown, "```json")?;
2450 writeln!(
2451 markdown,
2452 "{}",
2453 serde_json::to_string_pretty(&tool_use.input)?
2454 )?;
2455 writeln!(markdown, "```")?;
2456 }
2457
2458 for tool_result in self.tool_results_for_message(message.id) {
2459 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2460 if tool_result.is_error {
2461 write!(markdown, " (Error)")?;
2462 }
2463
2464 writeln!(markdown, "**\n")?;
2465 writeln!(markdown, "{}", tool_result.content)?;
2466 }
2467 }
2468
2469 Ok(String::from_utf8_lossy(&markdown).to_string())
2470 }
2471
    /// Marks agent edits within `buffer_range` of `buffer` as accepted.
    ///
    /// Thin delegate to [`ActionLog::keep_edits_in_range`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2482
    /// Marks every outstanding agent edit as accepted (delegates to the
    /// action log).
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2487
    /// Reverts agent edits within the given ranges of `buffer`.
    ///
    /// Thin delegate to [`ActionLog::reject_edits_in_ranges`]; the returned
    /// task resolves when the rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2498
    /// The action log tracking edits made by this thread's tools.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2502
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2506
2507 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2508 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2509 return;
2510 }
2511
2512 let now = Instant::now();
2513 if let Some(last) = self.last_auto_capture_at {
2514 if now.duration_since(last).as_secs() < 10 {
2515 return;
2516 }
2517 }
2518
2519 self.last_auto_capture_at = Some(now);
2520
2521 let thread_id = self.id().clone();
2522 let github_login = self
2523 .project
2524 .read(cx)
2525 .user_store()
2526 .read(cx)
2527 .current_user()
2528 .map(|user| user.github_login.clone());
2529 let client = self.project.read(cx).client().clone();
2530 let serialize_task = self.serialize(cx);
2531
2532 cx.background_executor()
2533 .spawn(async move {
2534 if let Ok(serialized_thread) = serialize_task.await {
2535 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2536 telemetry::event!(
2537 "Agent Thread Auto-Captured",
2538 thread_id = thread_id.to_string(),
2539 thread_data = thread_data,
2540 auto_capture_reason = "tracked_user",
2541 github_login = github_login
2542 );
2543
2544 client.telemetry().flush_events().await;
2545 }
2546 }
2547 })
2548 .detach();
2549 }
2550
    /// Total token usage accumulated over the lifetime of this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2554
2555 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2556 let Some(model) = self.configured_model.as_ref() else {
2557 return TotalTokenUsage::default();
2558 };
2559
2560 let max = model.model.max_token_count();
2561
2562 let index = self
2563 .messages
2564 .iter()
2565 .position(|msg| msg.id == message_id)
2566 .unwrap_or(0);
2567
2568 if index == 0 {
2569 return TotalTokenUsage { total: 0, max };
2570 }
2571
2572 let token_usage = &self
2573 .request_token_usage
2574 .get(index - 1)
2575 .cloned()
2576 .unwrap_or_default();
2577
2578 TotalTokenUsage {
2579 total: token_usage.total_tokens() as usize,
2580 max,
2581 }
2582 }
2583
2584 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2585 let model = self.configured_model.as_ref()?;
2586
2587 let max = model.model.max_token_count();
2588
2589 if let Some(exceeded_error) = &self.exceeded_window_error {
2590 if model.model.id() == exceeded_error.model_id {
2591 return Some(TotalTokenUsage {
2592 total: exceeded_error.token_count,
2593 max,
2594 });
2595 }
2596 }
2597
2598 let total = self
2599 .token_usage_at_last_message()
2600 .unwrap_or_default()
2601 .total_tokens() as usize;
2602
2603 Some(TotalTokenUsage { total, max })
2604 }
2605
    /// Returns the request token usage recorded for the last message, falling
    /// back to the most recent recorded entry when `request_token_usage` is
    /// shorter than `messages` (e.g. before a response has been recorded).
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2612
2613 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2614 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2615 self.request_token_usage
2616 .resize(self.messages.len(), placeholder);
2617
2618 if let Some(last) = self.request_token_usage.last_mut() {
2619 *last = token_usage;
2620 }
2621 }
2622
2623 pub fn deny_tool_use(
2624 &mut self,
2625 tool_use_id: LanguageModelToolUseId,
2626 tool_name: Arc<str>,
2627 window: Option<AnyWindowHandle>,
2628 cx: &mut Context<Self>,
2629 ) {
2630 let err = Err(anyhow::anyhow!(
2631 "Permission to run tool action denied by user"
2632 ));
2633
2634 self.tool_use.insert_tool_output(
2635 tool_use_id.clone(),
2636 tool_name,
2637 err,
2638 self.configured_model.as_ref(),
2639 );
2640 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2641 }
2642}
2643
/// User-facing errors raised while running a thread, surfaced through
/// [`ThreadEvent::ShowError`].
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The account requires payment before further requests can be made.
    #[error("Payment required")]
    PaymentRequired,
    /// The configured maximum monthly spend has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short header and a detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2658
/// Events emitted by a `Thread` as completions stream, tools run, and
/// messages change.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Surface the given error to the user.
    ShowError(ThreadError),
    /// A completion event was received and processed.
    StreamedCompletion,
    /// A chunk of text arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with the given input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// An expected tool use was not produced; `ui_text` describes it.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that could not be parsed; the raw JSON
    /// is preserved in `invalid_input_json`.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message's content changed.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// The listed pending tool uses should now be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's git checkpoint changed.
    CheckpointChanged,
    /// A tool is waiting for user confirmation before running.
    ToolConfirmationNeeded,
    /// Editing was canceled.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2701
// Allow `Thread` to emit `ThreadEvent`s to gpui subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2703
/// An in-flight completion request. Holding `_task` keeps the underlying
/// gpui task alive; dropping it cancels the work.
struct PendingCompletion {
    id: usize,
    queue_state: QueueState,
    _task: Task<()>,
}
2709
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
    use assistant_settings::{AssistantSettings, LanguageModelParameters};
    use assistant_tool::ToolRegistry;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use language_model::fake_provider::FakeLanguageModel;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // Verifies that a file attached as context is rendered into the user
    // message and carried through into the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    // Verifies that each message only carries context that wasn't already
    // included in an earlier message, and that removing contexts is
    // reflected when recomputing context for an older message.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Recomputing context "as of" message 2 should pick up everything
        // added after it (files 2, 3, and 4) but not file 1.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }

    // Verifies that messages inserted with empty context produce plain
    // request messages with no context wrapper.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    // Verifies that editing a buffer which was previously attached as
    // context causes a trailing "stale files" notification message in the
    // next completion request.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    // Verifies that the per-model/per-provider temperature override in
    // settings is applied only when both provider and model match (or when
    // the unset side is a wildcard).
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }

    // Registers all the settings/globals the tests above rely on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }

    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    // Builds the workspace, thread store, thread, context store, and a fake
    // language model used by every test in this module.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }

    // Opens `path` in the project and registers its buffer with the context
    // store, returning the buffer for further edits.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
}