1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Result, anyhow};
8use assistant_settings::{AgentProfileId, AssistantSettings, CompletionMode};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, TryFutureExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
51#[derive(
52 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
53)]
54pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
74/// The ID of the user prompt that initiated a request.
75///
76/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
77#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
78pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
92#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
93pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    // Range within the message text covered by the crease.
    pub range: Range<usize>,
    // Icon/label used to render the crease placeholder.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    // Ordered text/thinking segments making up the message body.
    pub segments: Vec<MessageSegment>,
    // Context that was attached when the message was sent; only the flattened
    // text survives deserialization (see `Thread::deserialize`).
    pub loaded_context: LoadedContext,
    // Creases used to re-render context mentions in editors.
    pub creases: Vec<MessageCrease>,
}
119
120impl Message {
121 /// Returns whether the message contains any meaningful text that should be displayed
122 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
123 pub fn should_display_content(&self) -> bool {
124 self.segments.iter().all(|segment| segment.should_display())
125 }
126
127 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
128 if let Some(MessageSegment::Thinking {
129 text: segment,
130 signature: current_signature,
131 }) = self.segments.last_mut()
132 {
133 if let Some(signature) = signature {
134 *current_signature = Some(signature);
135 }
136 segment.push_str(text);
137 } else {
138 self.segments.push(MessageSegment::Thinking {
139 text: text.to_string(),
140 signature,
141 });
142 }
143 }
144
145 pub fn push_text(&mut self, text: &str) {
146 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
147 segment.push_str(text);
148 } else {
149 self.segments.push(MessageSegment::Text(text.to_string()));
150 }
151 }
152
153 pub fn to_string(&self) -> String {
154 let mut result = String::new();
155
156 if !self.loaded_context.text.is_empty() {
157 result.push_str(&self.loaded_context.text);
158 }
159
160 for segment in &self.segments {
161 match segment {
162 MessageSegment::Text(text) => result.push_str(text),
163 MessageSegment::Thinking { text, .. } => {
164 result.push_str("<think>\n");
165 result.push_str(text);
166 result.push_str("\n</think>");
167 }
168 MessageSegment::RedactedThinking(_) => {}
169 }
170 }
171
172 result
173 }
174}
175
/// One contiguous piece of a message's content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain response text.
    Text(String),
    /// Chain-of-thought text, optionally carrying a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted thinking payload; never rendered.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Whether this segment carries meaningful, renderable content.
    ///
    /// Fixed: the emptiness checks were inverted, so empty segments claimed
    /// they should be displayed while non-empty ones did not. A segment is
    /// displayable only when its text is non-empty; redacted thinking is
    /// never displayed.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
195
/// Snapshot of the project's worktrees taken when a thread is created.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Paths of buffers that had unsaved changes at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree state captured in a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // None when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}

/// Git repository state captured for a worktree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
216
/// Associates a message with the git state captured when it was sent, so the
/// project can later be restored to that point.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up/down) on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
228
229pub enum LastRestoreCheckpoint {
230 Pending {
231 message_id: MessageId,
232 },
233 Error {
234 message_id: MessageId,
235 error: String,
236 },
237}
238
239impl LastRestoreCheckpoint {
240 pub fn message_id(&self) -> MessageId {
241 match self {
242 LastRestoreCheckpoint::Pending { message_id } => *message_id,
243 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
244 }
245 }
246}
247
248#[derive(Clone, Debug, Default, Serialize, Deserialize)]
249pub enum DetailedSummaryState {
250 #[default]
251 NotGenerated,
252 Generating {
253 message_id: MessageId,
254 },
255 Generated {
256 text: SharedString,
257 message_id: MessageId,
258 },
259}
260
261impl DetailedSummaryState {
262 fn text(&self) -> Option<SharedString> {
263 if let Self::Generated { text, .. } = self {
264 Some(text.clone())
265 } else {
266 None
267 }
268 }
269}
270
/// Running token count for a thread, together with the model's context limit.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size of the configured model (0 when unknown).
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or over it.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via
        // ZED_THREAD_WARNING_THRESHOLD. Fall back to the default when the
        // variable is missing or malformed instead of panicking (the old
        // `.parse().unwrap()` crashed on a malformed value), and avoid the
        // eager `"0.8".to_string()` allocation.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close a thread is to its model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
315
/// Progress of a pending completion request before streaming begins.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    // Request is being sent.
    Sending,
    // Request is waiting in a queue at the given position.
    Queued { position: usize },
    // The completion has started streaming.
    Started,
}
322
/// A thread of conversation with the LLM.
pub struct Thread {
    // Stable unique identifier for this thread.
    id: ThreadId,
    // Last time the thread was modified (see `touch_updated_at`).
    updated_at: DateTime<Utc>,
    // Short human-readable summary; None until one has been generated.
    summary: Option<SharedString>,
    // In-flight task generating `summary`, if any.
    pending_summary: Task<Option<()>>,
    // In-flight task generating the detailed summary, if any.
    detailed_summary_task: Task<Option<()>>,
    // Watch channel publishing the detailed-summary state.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: assistant_settings::CompletionMode,
    // All messages, ordered by ascending `MessageId` (relied on by the
    // binary search in `Thread::message`).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints recorded per message, for later restore.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    // Outcome of the most recent checkpoint restore, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint staged by the latest user message, finalized when
    // generation completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Project/git snapshot captured when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // When the most recent streaming chunk arrived; drives
    // `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Optional hook observing each completion request and its raw events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Remaining turns before `send_to_model` becomes a no-op.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    configured_profile_id: Option<AgentProfileId>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
372
373impl Thread {
374 pub fn new(
375 project: Entity<Project>,
376 tools: Entity<ToolWorkingSet>,
377 prompt_builder: Arc<PromptBuilder>,
378 system_prompt: SharedProjectContext,
379 cx: &mut Context<Self>,
380 ) -> Self {
381 let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
382 let configured_model = LanguageModelRegistry::read_global(cx).default_model();
383 let assistant_settings = AssistantSettings::get_global(cx);
384 let configured_profile_id = assistant_settings.default_profile.clone();
385
386 Self {
387 id: ThreadId::new(),
388 updated_at: Utc::now(),
389 summary: None,
390 pending_summary: Task::ready(None),
391 detailed_summary_task: Task::ready(None),
392 detailed_summary_tx,
393 detailed_summary_rx,
394 completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
395 messages: Vec::new(),
396 next_message_id: MessageId(0),
397 last_prompt_id: PromptId::new(),
398 project_context: system_prompt,
399 checkpoints_by_message: HashMap::default(),
400 completion_count: 0,
401 pending_completions: Vec::new(),
402 project: project.clone(),
403 prompt_builder,
404 tools: tools.clone(),
405 last_restore_checkpoint: None,
406 pending_checkpoint: None,
407 tool_use: ToolUseState::new(tools.clone()),
408 action_log: cx.new(|_| ActionLog::new(project.clone())),
409 initial_project_snapshot: {
410 let project_snapshot = Self::project_snapshot(project, cx);
411 cx.foreground_executor()
412 .spawn(async move { Some(project_snapshot.await) })
413 .shared()
414 },
415 request_token_usage: Vec::new(),
416 cumulative_token_usage: TokenUsage::default(),
417 exceeded_window_error: None,
418 last_usage: None,
419 tool_use_limit_reached: false,
420 feedback: None,
421 message_feedback: HashMap::default(),
422 last_auto_capture_at: None,
423 last_received_chunk_at: None,
424 request_callback: None,
425 remaining_turns: u32::MAX,
426 configured_model,
427 configured_profile_id: Some(configured_profile_id),
428 }
429 }
430
    /// Reconstructs a thread from its serialized representation.
    ///
    /// Messages, segments, creases, and tool-use state are rebuilt from
    /// `serialized`; transient state (pending completions, checkpoints,
    /// feedback, usage errors) starts out fresh.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest serialized message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model recorded in the serialized thread; fall back to
        // the global default when it is missing or cannot be resolved.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        // Older serialized threads may predate the completion-mode field.
        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);

        let configured_profile_id = serialized.profile.clone();

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        // Context handles/images are not serialized; only the
                        // flattened context text survives a round-trip.
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // None marks the crease as deserialized.
                            context: None,
                        })
                        .collect(),
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            configured_profile_id,
        }
    }
553
    /// Installs a hook that observes every completion request together with
    /// the raw events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// The stable identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// True when the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Generates a fresh [`PromptId`] to tag subsequent requests with.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The generated summary, if one exists.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }

    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, falling back to (and caching) the
    /// global default model when none has been chosen yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets the model used for subsequent completions.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// The currently configured agent profile, if any.
    pub fn configured_profile_id(&self) -> Option<AgentProfileId> {
        self.configured_profile_id.clone()
    }

    /// Sets the agent profile used for subsequent turns.
    pub fn set_configured_profile_id(
        &mut self,
        id: Option<AgentProfileId>,
        cx: &mut Context<Self>,
    ) {
        self.configured_profile_id = id;
        cx.notify();
    }

    /// Placeholder summary shown before a real one has been generated.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

    /// The generated summary, or [`Self::DEFAULT_SUMMARY`] when none exists.
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
624
625 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
626 let Some(current_summary) = &self.summary else {
627 // Don't allow setting summary until generated
628 return;
629 };
630
631 let mut new_summary = new_summary.into();
632
633 if new_summary.is_empty() {
634 new_summary = Self::DEFAULT_SUMMARY;
635 }
636
637 if current_summary != &new_summary {
638 self.summary = Some(new_summary);
639 cx.emit(ThreadEvent::SummaryChanged);
640 }
641 }
642
    /// Completion mode used for requests.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    /// Looks up a message by ID.
    ///
    /// Uses binary search: `messages` is kept ordered by ascending
    /// `MessageId` (IDs are assigned sequentially on insertion).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// True while a completion is pending or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
667
    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    ///
    /// NOTE(review): staleness is derived solely from the time of the last
    /// received chunk; the `is_generating()` claim above assumes
    /// `last_received_chunk_at` is reset when generation ends — confirm
    /// against the code that clears it.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Threshold in milliseconds.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by its ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// True when any tool use is still pending (including errored ones).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
708
    /// The git checkpoint recorded for a message, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

    /// Restores the project's git state to `checkpoint` and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Progress and failure are surfaced through `last_restore_checkpoint`
    /// and [`ThreadEvent::CheckpointChanged`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so the UI can show why the restore
                    // failed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop the checkpoint's message and everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
748
    /// Once generation has fully finished, compares the pending checkpoint's
    /// git state against the current state and keeps the checkpoint only
    /// when the two differ (i.e. the turn actually changed something).
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A comparison failure is treated as "not equal" so the
                // checkpoint is conservatively kept.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If the final checkpoint can't be captured, keep the pending
            // one rather than losing it.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
794
    /// Outcome of the most recent checkpoint restore, if one is in flight or
    /// failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }

    /// Removes the message with `message_id` and everything after it, along
    /// with any checkpoints recorded for the removed messages.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
812
813 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
814 self.messages
815 .iter()
816 .find(|message| message.id == id)
817 .into_iter()
818 .flat_map(|message| message.loaded_context.contexts.iter())
819 }
820
821 pub fn is_turn_end(&self, ix: usize) -> bool {
822 if self.messages.is_empty() {
823 return false;
824 }
825
826 if !self.is_generating() && ix == self.messages.len() - 1 {
827 return true;
828 }
829
830 let Some(message) = self.messages.get(ix) else {
831 return false;
832 };
833
834 if message.role != Role::Assistant {
835 return false;
836 }
837
838 self.messages
839 .get(ix + 1)
840 .and_then(|message| {
841 self.message(message.id)
842 .map(|next_message| next_message.role == Role::User)
843 })
844 .unwrap_or(false)
845 }
846
    /// Usage reported by the most recent completion request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the tool-use limit was reached on the last request.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
864
    /// Tool uses issued by the given assistant message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results associated with the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if it has one.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The textual content of a tool use's result, if available.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        Some(&self.tool_use.tool_result(id)?.content)
    }

    /// The UI card associated with a tool use's result, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
887
888 /// Return tools that are both enabled and supported by the model
889 pub fn available_tools(
890 &self,
891 cx: &App,
892 model: Arc<dyn LanguageModel>,
893 ) -> Vec<LanguageModelRequestTool> {
894 if model.supports_tools() {
895 self.tools()
896 .read(cx)
897 .enabled_tools(cx)
898 .into_iter()
899 .filter_map(|tool| {
900 // Skip tools that cannot be supported
901 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
902 Some(LanguageModelRequestTool {
903 name: tool.name(),
904 description: tool.description(),
905 input_schema,
906 })
907 })
908 .collect()
909 } else {
910 Vec::default()
911 }
912 }
913
    /// Appends a user message, records context-referenced buffers in the
    /// action log, and stages `git_checkpoint` as the pending checkpoint for
    /// the new message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            // Mark every buffer mentioned by the attached context as read.
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            // Finalized (kept or discarded) once generation completes; see
            // `finalize_pending_checkpoint`.
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
949
    /// Appends an assistant message with no attached context or creases.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            cx,
        )
    }

    /// Appends a message with the next sequential ID, bumps `updated_at`,
    /// and emits [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
984
985 pub fn edit_message(
986 &mut self,
987 id: MessageId,
988 new_role: Role,
989 new_segments: Vec<MessageSegment>,
990 loaded_context: Option<LoadedContext>,
991 cx: &mut Context<Self>,
992 ) -> bool {
993 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
994 return false;
995 };
996 message.role = new_role;
997 message.segments = new_segments;
998 if let Some(context) = loaded_context {
999 message.loaded_context = context;
1000 }
1001 self.touch_updated_at();
1002 cx.emit(ThreadEvent::MessageEdited(id));
1003 true
1004 }
1005
1006 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1007 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1008 return false;
1009 };
1010 self.messages.remove(index);
1011 self.touch_updated_at();
1012 cx.emit(ThreadEvent::MessageDeleted(id));
1013 true
1014 }
1015
1016 /// Returns the representation of this [`Thread`] in a textual form.
1017 ///
1018 /// This is the representation we use when attaching a thread as context to another thread.
1019 pub fn text(&self) -> String {
1020 let mut text = String::new();
1021
1022 for message in &self.messages {
1023 text.push_str(match message.role {
1024 language_model::Role::User => "User:",
1025 language_model::Role::Assistant => "Agent:",
1026 language_model::Role::System => "System:",
1027 });
1028 text.push('\n');
1029
1030 for segment in &message.segments {
1031 match segment {
1032 MessageSegment::Text(content) => text.push_str(content),
1033 MessageSegment::Thinking { text: content, .. } => {
1034 text.push_str(&format!("<think>{}</think>", content))
1035 }
1036 MessageSegment::RedactedThinking(_) => {}
1037 }
1038 }
1039 text.push('\n');
1040 }
1041
1042 text
1043 }
1044
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the creation-time snapshot before reading the rest of
            // the thread state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results live in `tool_use` state keyed by
                        // message, not on the message itself.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                profile: this.configured_profile_id.clone(),
                completion_mode: Some(this.completion_mode),
            })
        })
    }
1128
    /// Returns how many agentic turns this thread may still take before
    /// `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1132
    /// Resets the budget of agentic turns this thread is allowed to take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1136
1137 pub fn send_to_model(
1138 &mut self,
1139 model: Arc<dyn LanguageModel>,
1140 window: Option<AnyWindowHandle>,
1141 cx: &mut Context<Self>,
1142 ) {
1143 if self.remaining_turns == 0 {
1144 return;
1145 }
1146
1147 self.remaining_turns -= 1;
1148
1149 let request = self.to_completion_request(model.clone(), cx);
1150
1151 self.stream_completion(request, model, window, cx);
1152 }
1153
1154 pub fn used_tools_since_last_user_message(&self) -> bool {
1155 for message in self.messages.iter().rev() {
1156 if self.tool_use.message_has_tool_results(message.id) {
1157 return true;
1158 } else if message.role == Role::User {
1159 return false;
1160 }
1161 }
1162
1163 false
1164 }
1165
    /// Assembles the full `LanguageModelRequest` for this thread: system
    /// prompt, conversation history (including resolved tool uses/results),
    /// stale-file notices, available tools, and completion mode.
    ///
    /// Emits a `ThreadEvent::ShowError` (and omits the system message) when
    /// the system prompt cannot be generated.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AssistantSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt depends on project context; if that context isn't
        // ready yet, surface an error instead of silently sending a
        // system-prompt-less request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message that is safe to mark for prompt
        // caching: only messages whose tool uses are all resolved qualify.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            // Flatten this message's segments into request content, skipping
            // empty text/thinking chunks.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool results are sent back to the model as a synthetic User
            // message following the assistant message that requested them.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A pending (unfinished) tool use makes this message
                    // ineligible for caching and is omitted from the request.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1323
1324 fn to_summarize_request(
1325 &self,
1326 model: &Arc<dyn LanguageModel>,
1327 added_user_message: String,
1328 cx: &App,
1329 ) -> LanguageModelRequest {
1330 let mut request = LanguageModelRequest {
1331 thread_id: None,
1332 prompt_id: None,
1333 mode: None,
1334 messages: vec![],
1335 tools: Vec::new(),
1336 tool_choice: None,
1337 stop: Vec::new(),
1338 temperature: AssistantSettings::temperature_for_model(model, cx),
1339 };
1340
1341 for message in &self.messages {
1342 let mut request_message = LanguageModelRequestMessage {
1343 role: message.role,
1344 content: Vec::new(),
1345 cache: false,
1346 };
1347
1348 for segment in &message.segments {
1349 match segment {
1350 MessageSegment::Text(text) => request_message
1351 .content
1352 .push(MessageContent::Text(text.clone())),
1353 MessageSegment::Thinking { .. } => {}
1354 MessageSegment::RedactedThinking(_) => {}
1355 }
1356 }
1357
1358 if request_message.content.is_empty() {
1359 continue;
1360 }
1361
1362 request.messages.push(request_message);
1363 }
1364
1365 request.messages.push(LanguageModelRequestMessage {
1366 role: Role::User,
1367 content: vec![MessageContent::Text(added_user_message)],
1368 cache: false,
1369 });
1370
1371 request
1372 }
1373
1374 fn attached_tracked_files_state(
1375 &self,
1376 messages: &mut Vec<LanguageModelRequestMessage>,
1377 cx: &App,
1378 ) {
1379 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1380
1381 let mut stale_message = String::new();
1382
1383 let action_log = self.action_log.read(cx);
1384
1385 for stale_file in action_log.stale_buffers(cx) {
1386 let Some(file) = stale_file.read(cx).file() else {
1387 continue;
1388 };
1389
1390 if stale_message.is_empty() {
1391 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1392 }
1393
1394 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1395 }
1396
1397 let mut content = Vec::with_capacity(2);
1398
1399 if !stale_message.is_empty() {
1400 content.push(stale_message.into());
1401 }
1402
1403 if !content.is_empty() {
1404 let context_message = LanguageModelRequestMessage {
1405 role: Role::User,
1406 content,
1407 cache: false,
1408 };
1409
1410 messages.push(context_message);
1411 }
1412 }
1413
1414 pub fn stream_completion(
1415 &mut self,
1416 request: LanguageModelRequest,
1417 model: Arc<dyn LanguageModel>,
1418 window: Option<AnyWindowHandle>,
1419 cx: &mut Context<Self>,
1420 ) {
1421 self.tool_use_limit_reached = false;
1422
1423 let pending_completion_id = post_inc(&mut self.completion_count);
1424 let mut request_callback_parameters = if self.request_callback.is_some() {
1425 Some((request.clone(), Vec::new()))
1426 } else {
1427 None
1428 };
1429 let prompt_id = self.last_prompt_id.clone();
1430 let tool_use_metadata = ToolUseMetadata {
1431 model: model.clone(),
1432 thread_id: self.id.clone(),
1433 prompt_id: prompt_id.clone(),
1434 };
1435
1436 self.last_received_chunk_at = Some(Instant::now());
1437
1438 let task = cx.spawn(async move |thread, cx| {
1439 let stream_completion_future = model.stream_completion(request, &cx);
1440 let initial_token_usage =
1441 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1442 let stream_completion = async {
1443 let mut events = stream_completion_future.await?;
1444
1445 let mut stop_reason = StopReason::EndTurn;
1446 let mut current_token_usage = TokenUsage::default();
1447
1448 thread
1449 .update(cx, |_thread, cx| {
1450 cx.emit(ThreadEvent::NewRequest);
1451 })
1452 .ok();
1453
1454 let mut request_assistant_message_id = None;
1455
1456 while let Some(event) = events.next().await {
1457 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1458 response_events
1459 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1460 }
1461
1462 thread.update(cx, |thread, cx| {
1463 let event = match event {
1464 Ok(event) => event,
1465 Err(LanguageModelCompletionError::BadInputJson {
1466 id,
1467 tool_name,
1468 raw_input: invalid_input_json,
1469 json_parse_error,
1470 }) => {
1471 thread.receive_invalid_tool_json(
1472 id,
1473 tool_name,
1474 invalid_input_json,
1475 json_parse_error,
1476 window,
1477 cx,
1478 );
1479 return Ok(());
1480 }
1481 Err(LanguageModelCompletionError::Other(error)) => {
1482 return Err(error);
1483 }
1484 };
1485
1486 match event {
1487 LanguageModelCompletionEvent::StartMessage { .. } => {
1488 request_assistant_message_id =
1489 Some(thread.insert_assistant_message(
1490 vec![MessageSegment::Text(String::new())],
1491 cx,
1492 ));
1493 }
1494 LanguageModelCompletionEvent::Stop(reason) => {
1495 stop_reason = reason;
1496 }
1497 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1498 thread.update_token_usage_at_last_message(token_usage);
1499 thread.cumulative_token_usage = thread.cumulative_token_usage
1500 + token_usage
1501 - current_token_usage;
1502 current_token_usage = token_usage;
1503 }
1504 LanguageModelCompletionEvent::Text(chunk) => {
1505 thread.received_chunk();
1506
1507 cx.emit(ThreadEvent::ReceivedTextChunk);
1508 if let Some(last_message) = thread.messages.last_mut() {
1509 if last_message.role == Role::Assistant
1510 && !thread.tool_use.has_tool_results(last_message.id)
1511 {
1512 last_message.push_text(&chunk);
1513 cx.emit(ThreadEvent::StreamedAssistantText(
1514 last_message.id,
1515 chunk,
1516 ));
1517 } else {
1518 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1519 // of a new Assistant response.
1520 //
1521 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1522 // will result in duplicating the text of the chunk in the rendered Markdown.
1523 request_assistant_message_id =
1524 Some(thread.insert_assistant_message(
1525 vec![MessageSegment::Text(chunk.to_string())],
1526 cx,
1527 ));
1528 };
1529 }
1530 }
1531 LanguageModelCompletionEvent::Thinking {
1532 text: chunk,
1533 signature,
1534 } => {
1535 thread.received_chunk();
1536
1537 if let Some(last_message) = thread.messages.last_mut() {
1538 if last_message.role == Role::Assistant
1539 && !thread.tool_use.has_tool_results(last_message.id)
1540 {
1541 last_message.push_thinking(&chunk, signature);
1542 cx.emit(ThreadEvent::StreamedAssistantThinking(
1543 last_message.id,
1544 chunk,
1545 ));
1546 } else {
1547 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1548 // of a new Assistant response.
1549 //
1550 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1551 // will result in duplicating the text of the chunk in the rendered Markdown.
1552 request_assistant_message_id =
1553 Some(thread.insert_assistant_message(
1554 vec![MessageSegment::Thinking {
1555 text: chunk.to_string(),
1556 signature,
1557 }],
1558 cx,
1559 ));
1560 };
1561 }
1562 }
1563 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1564 let last_assistant_message_id = request_assistant_message_id
1565 .unwrap_or_else(|| {
1566 let new_assistant_message_id =
1567 thread.insert_assistant_message(vec![], cx);
1568 request_assistant_message_id =
1569 Some(new_assistant_message_id);
1570 new_assistant_message_id
1571 });
1572
1573 let tool_use_id = tool_use.id.clone();
1574 let streamed_input = if tool_use.is_input_complete {
1575 None
1576 } else {
1577 Some((&tool_use.input).clone())
1578 };
1579
1580 let ui_text = thread.tool_use.request_tool_use(
1581 last_assistant_message_id,
1582 tool_use,
1583 tool_use_metadata.clone(),
1584 cx,
1585 );
1586
1587 if let Some(input) = streamed_input {
1588 cx.emit(ThreadEvent::StreamedToolUse {
1589 tool_use_id,
1590 ui_text,
1591 input,
1592 });
1593 }
1594 }
1595 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1596 if let Some(completion) = thread
1597 .pending_completions
1598 .iter_mut()
1599 .find(|completion| completion.id == pending_completion_id)
1600 {
1601 match status_update {
1602 CompletionRequestStatus::Queued {
1603 position,
1604 } => {
1605 completion.queue_state = QueueState::Queued { position };
1606 }
1607 CompletionRequestStatus::Started => {
1608 completion.queue_state = QueueState::Started;
1609 }
1610 CompletionRequestStatus::Failed {
1611 code, message, request_id
1612 } => {
1613 return Err(anyhow!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"));
1614 }
1615 CompletionRequestStatus::UsageUpdated {
1616 amount, limit
1617 } => {
1618 let usage = RequestUsage { limit, amount: amount as i32 };
1619
1620 thread.last_usage = Some(usage);
1621 }
1622 CompletionRequestStatus::ToolUseLimitReached => {
1623 thread.tool_use_limit_reached = true;
1624 }
1625 }
1626 }
1627 }
1628 }
1629
1630 thread.touch_updated_at();
1631 cx.emit(ThreadEvent::StreamedCompletion);
1632 cx.notify();
1633
1634 thread.auto_capture_telemetry(cx);
1635 Ok(())
1636 })??;
1637
1638 smol::future::yield_now().await;
1639 }
1640
1641 thread.update(cx, |thread, cx| {
1642 thread.last_received_chunk_at = None;
1643 thread
1644 .pending_completions
1645 .retain(|completion| completion.id != pending_completion_id);
1646
1647 // If there is a response without tool use, summarize the message. Otherwise,
1648 // allow two tool uses before summarizing.
1649 if thread.summary.is_none()
1650 && thread.messages.len() >= 2
1651 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1652 {
1653 thread.summarize(cx);
1654 }
1655 })?;
1656
1657 anyhow::Ok(stop_reason)
1658 };
1659
1660 let result = stream_completion.await;
1661
1662 thread
1663 .update(cx, |thread, cx| {
1664 thread.finalize_pending_checkpoint(cx);
1665 match result.as_ref() {
1666 Ok(stop_reason) => match stop_reason {
1667 StopReason::ToolUse => {
1668 let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1669 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1670 }
1671 StopReason::EndTurn | StopReason::MaxTokens => {
1672 thread.project.update(cx, |project, cx| {
1673 project.set_agent_location(None, cx);
1674 });
1675 }
1676 },
1677 Err(error) => {
1678 thread.project.update(cx, |project, cx| {
1679 project.set_agent_location(None, cx);
1680 });
1681
1682 if error.is::<PaymentRequiredError>() {
1683 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1684 } else if error.is::<MaxMonthlySpendReachedError>() {
1685 cx.emit(ThreadEvent::ShowError(
1686 ThreadError::MaxMonthlySpendReached,
1687 ));
1688 } else if let Some(error) =
1689 error.downcast_ref::<ModelRequestLimitReachedError>()
1690 {
1691 cx.emit(ThreadEvent::ShowError(
1692 ThreadError::ModelRequestLimitReached { plan: error.plan },
1693 ));
1694 } else if let Some(known_error) =
1695 error.downcast_ref::<LanguageModelKnownError>()
1696 {
1697 match known_error {
1698 LanguageModelKnownError::ContextWindowLimitExceeded {
1699 tokens,
1700 } => {
1701 thread.exceeded_window_error = Some(ExceededWindowError {
1702 model_id: model.id(),
1703 token_count: *tokens,
1704 });
1705 cx.notify();
1706 }
1707 }
1708 } else {
1709 let error_message = error
1710 .chain()
1711 .map(|err| err.to_string())
1712 .collect::<Vec<_>>()
1713 .join("\n");
1714 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1715 header: "Error interacting with language model".into(),
1716 message: SharedString::from(error_message.clone()),
1717 }));
1718 }
1719
1720 thread.cancel_last_completion(window, cx);
1721 }
1722 }
1723 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1724
1725 if let Some((request_callback, (request, response_events))) = thread
1726 .request_callback
1727 .as_mut()
1728 .zip(request_callback_parameters.as_ref())
1729 {
1730 request_callback(request, response_events);
1731 }
1732
1733 thread.auto_capture_telemetry(cx);
1734
1735 if let Ok(initial_usage) = initial_token_usage {
1736 let usage = thread.cumulative_token_usage - initial_usage;
1737
1738 telemetry::event!(
1739 "Assistant Thread Completion",
1740 thread_id = thread.id().to_string(),
1741 prompt_id = prompt_id,
1742 model = model.telemetry_id(),
1743 model_provider = model.provider_id().to_string(),
1744 input_tokens = usage.input_tokens,
1745 output_tokens = usage.output_tokens,
1746 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1747 cache_read_input_tokens = usage.cache_read_input_tokens,
1748 );
1749 }
1750 })
1751 .ok();
1752 });
1753
1754 self.pending_completions.push(PendingCompletion {
1755 id: pending_completion_id,
1756 queue_state: QueueState::Sending,
1757 _task: task,
1758 });
1759 }
1760
    /// Spawns a background task that asks the configured summary model for a
    /// short (3-7 word) title for this thread.
    ///
    /// No-ops when no summary model is configured or its provider is not
    /// authenticated. On success, stores the title in `self.summary` and
    /// emits `ThreadEvent::SummaryGenerated`.
    pub fn summarize(&mut self, cx: &mut Context<Self>) {
        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
            return;
        };

        if !model.provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
            Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
            If the conversation is about a specific subject, include it in the title. \
            Be descriptive. DO NOT speak in the first person.";

        let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);

        self.pending_summary = cx.spawn(async move |this, cx| {
            async move {
                let mut messages = model.model.stream_completion(request, &cx).await?;

                let mut new_summary = String::new();
                while let Some(event) = messages.next().await {
                    let event = event?;
                    let text = match event {
                        LanguageModelCompletionEvent::Text(text) => text,
                        LanguageModelCompletionEvent::StatusUpdate(
                            CompletionRequestStatus::UsageUpdated { amount, limit },
                        ) => {
                            // Usage updates still count toward the budget even
                            // for summary requests; record and keep streaming.
                            this.update(cx, |thread, _cx| {
                                thread.last_usage = Some(RequestUsage {
                                    limit,
                                    amount: amount as i32,
                                });
                            })?;
                            continue;
                        }
                        _ => continue,
                    };

                    // Only the first line of output is kept as the title.
                    let mut lines = text.lines();
                    new_summary.extend(lines.next());

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(cx, |this, cx| {
                    if !new_summary.is_empty() {
                        this.summary = Some(new_summary.into());
                    }

                    cx.emit(ThreadEvent::SummaryGenerated);
                })?;

                anyhow::Ok(())
            }
            .log_err()
            .await
        });
    }
1823
    /// Starts generating a detailed Markdown summary of the thread unless one
    /// is already generated (or generating) for the current last message.
    ///
    /// Progress is published through the `detailed_summary_tx` watch channel;
    /// on completion the thread is saved via `thread_store` so the summary
    /// can be reused later. No-ops when the thread is empty, no summary model
    /// is configured, or its provider is not authenticated.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // The request failed; roll the state back so a retry is possible.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1913
    /// Waits on the detailed-summary watch channel until a summary is
    /// available (returning it), or falls back to the thread's raw text if
    /// generation is not in progress.
    ///
    /// Returns `None` if the channel closes or the thread entity is dropped.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating; keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
1931
1932 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
1933 self.detailed_summary_rx
1934 .borrow()
1935 .text()
1936 .unwrap_or_else(|| self.text().into())
1937 }
1938
    /// Reports whether a detailed-summary generation task is currently in
    /// flight for this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
1945
    /// Dispatches every idle pending tool use: tools requiring confirmation
    /// are queued for user approval (unless the "always allow" setting is
    /// on), known tools are run immediately, and unknown tool names get a
    /// hallucination error fed back to the model.
    ///
    /// Returns the pending tool uses that were considered.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // The same request snapshot is shared by every tool invocation below.
        let request = Arc::new(self.to_completion_request(model.clone(), cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool that doesn't exist / isn't enabled.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
1999
2000 pub fn handle_hallucinated_tool_use(
2001 &mut self,
2002 tool_use_id: LanguageModelToolUseId,
2003 hallucinated_tool_name: Arc<str>,
2004 window: Option<AnyWindowHandle>,
2005 cx: &mut Context<Thread>,
2006 ) {
2007 let available_tools = self.tools.read(cx).enabled_tools(cx);
2008
2009 let tool_list = available_tools
2010 .iter()
2011 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2012 .collect::<Vec<_>>()
2013 .join("\n");
2014
2015 let error_message = format!(
2016 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2017 hallucinated_tool_name, tool_list
2018 );
2019
2020 let pending_tool_use = self.tool_use.insert_tool_output(
2021 tool_use_id.clone(),
2022 hallucinated_tool_name,
2023 Err(anyhow!("Missing tool call: {error_message}")),
2024 self.configured_model.as_ref(),
2025 );
2026
2027 cx.emit(ThreadEvent::MissingToolUse {
2028 tool_use_id: tool_use_id.clone(),
2029 ui_text: error_message.into(),
2030 });
2031
2032 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2033 }
2034
2035 pub fn receive_invalid_tool_json(
2036 &mut self,
2037 tool_use_id: LanguageModelToolUseId,
2038 tool_name: Arc<str>,
2039 invalid_json: Arc<str>,
2040 error: String,
2041 window: Option<AnyWindowHandle>,
2042 cx: &mut Context<Thread>,
2043 ) {
2044 log::error!("The model returned invalid input JSON: {invalid_json}");
2045
2046 let pending_tool_use = self.tool_use.insert_tool_output(
2047 tool_use_id.clone(),
2048 tool_name,
2049 Err(anyhow!("Error parsing input JSON: {error}")),
2050 self.configured_model.as_ref(),
2051 );
2052 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2053 pending_tool_use.ui_text.clone()
2054 } else {
2055 log::error!(
2056 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2057 );
2058 format!("Unknown tool {}", tool_use_id).into()
2059 };
2060
2061 cx.emit(ThreadEvent::InvalidToolInput {
2062 tool_use_id: tool_use_id.clone(),
2063 ui_text,
2064 invalid_input_json: invalid_json,
2065 });
2066
2067 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2068 }
2069
2070 pub fn run_tool(
2071 &mut self,
2072 tool_use_id: LanguageModelToolUseId,
2073 ui_text: impl Into<SharedString>,
2074 input: serde_json::Value,
2075 request: Arc<LanguageModelRequest>,
2076 tool: Arc<dyn Tool>,
2077 model: Arc<dyn LanguageModel>,
2078 window: Option<AnyWindowHandle>,
2079 cx: &mut Context<Thread>,
2080 ) {
2081 let task =
2082 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2083 self.tool_use
2084 .run_pending_tool(tool_use_id, ui_text.into(), task);
2085 }
2086
    /// Runs `tool` (unless it has been disabled, in which case an error
    /// result is produced immediately) and returns a task that records the
    /// tool's output on this thread when it completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // A disabled tool still produces a result so the model isn't left
        // with a dangling tool use.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                request,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2137
2138 fn tool_finished(
2139 &mut self,
2140 tool_use_id: LanguageModelToolUseId,
2141 pending_tool_use: Option<PendingToolUse>,
2142 canceled: bool,
2143 window: Option<AnyWindowHandle>,
2144 cx: &mut Context<Self>,
2145 ) {
2146 if self.all_tools_finished() {
2147 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2148 if !canceled {
2149 self.send_to_model(model.clone(), window, cx);
2150 }
2151 self.auto_capture_telemetry(cx);
2152 }
2153 }
2154
2155 cx.emit(ThreadEvent::ToolFinished {
2156 tool_use_id,
2157 pending_tool_use,
2158 });
2159 }
2160
2161 /// Cancels the last pending completion, if there are any pending.
2162 ///
2163 /// Returns whether a completion was canceled.
2164 pub fn cancel_last_completion(
2165 &mut self,
2166 window: Option<AnyWindowHandle>,
2167 cx: &mut Context<Self>,
2168 ) -> bool {
2169 let mut canceled = self.pending_completions.pop().is_some();
2170
2171 for pending_tool_use in self.tool_use.cancel_pending() {
2172 canceled = true;
2173 self.tool_finished(
2174 pending_tool_use.id.clone(),
2175 Some(pending_tool_use),
2176 true,
2177 window,
2178 cx,
2179 );
2180 }
2181
2182 self.finalize_pending_checkpoint(cx);
2183
2184 if canceled {
2185 cx.emit(ThreadEvent::CompletionCanceled);
2186 }
2187
2188 canceled
2189 }
2190
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2198
    /// Returns the thread-level feedback rating, if one was recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2202
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2206
    /// Records feedback for a single message and reports it (together with a
    /// serialized thread snapshot and project snapshot) via telemetry.
    ///
    /// Returns early with `Ok(())` if the same rating was already recorded
    /// for this message. The returned task completes once the telemetry
    /// events have been flushed.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Re-submitting the identical rating is a no-op.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything the background task needs before moving off the
        // main thread.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2264
    /// Records `feedback` for the thread as a whole and reports it via
    /// telemetry.
    ///
    /// If the thread contains an assistant message, the rating is attached
    /// to the most recent one via [`Self::report_message_feedback`];
    /// otherwise it is stored as thread-level feedback and reported without
    /// message-specific fields.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet: record the rating on the thread
            // itself and emit the event without message details.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Degrade to `Null` rather than dropping the event when the
                // serialized thread can't be converted to JSON.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2310
2311 /// Create a snapshot of the current project state including git information and unsaved buffers.
2312 fn project_snapshot(
2313 project: Entity<Project>,
2314 cx: &mut Context<Self>,
2315 ) -> Task<Arc<ProjectSnapshot>> {
2316 let git_store = project.read(cx).git_store().clone();
2317 let worktree_snapshots: Vec<_> = project
2318 .read(cx)
2319 .visible_worktrees(cx)
2320 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2321 .collect();
2322
2323 cx.spawn(async move |_, cx| {
2324 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2325
2326 let mut unsaved_buffers = Vec::new();
2327 cx.update(|app_cx| {
2328 let buffer_store = project.read(app_cx).buffer_store();
2329 for buffer_handle in buffer_store.read(app_cx).buffers() {
2330 let buffer = buffer_handle.read(app_cx);
2331 if buffer.is_dirty() {
2332 if let Some(file) = buffer.file() {
2333 let path = file.path().to_string_lossy().to_string();
2334 unsaved_buffers.push(path);
2335 }
2336 }
2337 }
2338 })
2339 .ok();
2340
2341 Arc::new(ProjectSnapshot {
2342 worktree_snapshots,
2343 unsaved_buffer_paths: unsaved_buffers,
2344 timestamp: Utc::now(),
2345 })
2346 })
2347 }
2348
2349 fn worktree_snapshot(
2350 worktree: Entity<project::Worktree>,
2351 git_store: Entity<GitStore>,
2352 cx: &App,
2353 ) -> Task<WorktreeSnapshot> {
2354 cx.spawn(async move |cx| {
2355 // Get worktree path and snapshot
2356 let worktree_info = cx.update(|app_cx| {
2357 let worktree = worktree.read(app_cx);
2358 let path = worktree.abs_path().to_string_lossy().to_string();
2359 let snapshot = worktree.snapshot();
2360 (path, snapshot)
2361 });
2362
2363 let Ok((worktree_path, _snapshot)) = worktree_info else {
2364 return WorktreeSnapshot {
2365 worktree_path: String::new(),
2366 git_state: None,
2367 };
2368 };
2369
2370 let git_state = git_store
2371 .update(cx, |git_store, cx| {
2372 git_store
2373 .repositories()
2374 .values()
2375 .find(|repo| {
2376 repo.read(cx)
2377 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2378 .is_some()
2379 })
2380 .cloned()
2381 })
2382 .ok()
2383 .flatten()
2384 .map(|repo| {
2385 repo.update(cx, |repo, _| {
2386 let current_branch =
2387 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2388 repo.send_job(None, |state, _| async move {
2389 let RepositoryState::Local { backend, .. } = state else {
2390 return GitState {
2391 remote_url: None,
2392 head_sha: None,
2393 current_branch,
2394 diff: None,
2395 };
2396 };
2397
2398 let remote_url = backend.remote_url("origin");
2399 let head_sha = backend.head_sha().await;
2400 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2401
2402 GitState {
2403 remote_url,
2404 head_sha,
2405 current_branch,
2406 diff,
2407 }
2408 })
2409 })
2410 });
2411
2412 let git_state = match git_state {
2413 Some(git_state) => match git_state.ok() {
2414 Some(git_state) => git_state.await.ok(),
2415 None => None,
2416 },
2417 None => None,
2418 };
2419
2420 WorktreeSnapshot {
2421 worktree_path,
2422 git_state,
2423 }
2424 })
2425 }
2426
2427 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2428 let mut markdown = Vec::new();
2429
2430 if let Some(summary) = self.summary() {
2431 writeln!(markdown, "# {summary}\n")?;
2432 };
2433
2434 for message in self.messages() {
2435 writeln!(
2436 markdown,
2437 "## {role}\n",
2438 role = match message.role {
2439 Role::User => "User",
2440 Role::Assistant => "Agent",
2441 Role::System => "System",
2442 }
2443 )?;
2444
2445 if !message.loaded_context.text.is_empty() {
2446 writeln!(markdown, "{}", message.loaded_context.text)?;
2447 }
2448
2449 if !message.loaded_context.images.is_empty() {
2450 writeln!(
2451 markdown,
2452 "\n{} images attached as context.\n",
2453 message.loaded_context.images.len()
2454 )?;
2455 }
2456
2457 for segment in &message.segments {
2458 match segment {
2459 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2460 MessageSegment::Thinking { text, .. } => {
2461 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2462 }
2463 MessageSegment::RedactedThinking(_) => {}
2464 }
2465 }
2466
2467 for tool_use in self.tool_uses_for_message(message.id, cx) {
2468 writeln!(
2469 markdown,
2470 "**Use Tool: {} ({})**",
2471 tool_use.name, tool_use.id
2472 )?;
2473 writeln!(markdown, "```json")?;
2474 writeln!(
2475 markdown,
2476 "{}",
2477 serde_json::to_string_pretty(&tool_use.input)?
2478 )?;
2479 writeln!(markdown, "```")?;
2480 }
2481
2482 for tool_result in self.tool_results_for_message(message.id) {
2483 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2484 if tool_result.is_error {
2485 write!(markdown, " (Error)")?;
2486 }
2487
2488 writeln!(markdown, "**\n")?;
2489 writeln!(markdown, "{}", tool_result.content)?;
2490 if let Some(output) = tool_result.output.as_ref() {
2491 writeln!(
2492 markdown,
2493 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2494 serde_json::to_string_pretty(output)?
2495 )?;
2496 }
2497 }
2498 }
2499
2500 Ok(String::from_utf8_lossy(&markdown).to_string())
2501 }
2502
2503 pub fn keep_edits_in_range(
2504 &mut self,
2505 buffer: Entity<language::Buffer>,
2506 buffer_range: Range<language::Anchor>,
2507 cx: &mut Context<Self>,
2508 ) {
2509 self.action_log.update(cx, |action_log, cx| {
2510 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2511 });
2512 }
2513
2514 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2515 self.action_log
2516 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2517 }
2518
2519 pub fn reject_edits_in_ranges(
2520 &mut self,
2521 buffer: Entity<language::Buffer>,
2522 buffer_ranges: Vec<Range<language::Anchor>>,
2523 cx: &mut Context<Self>,
2524 ) -> Task<Result<()>> {
2525 self.action_log.update(cx, |action_log, cx| {
2526 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2527 })
2528 }
2529
    /// Returns the action log tracking this thread's edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2533
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2537
2538 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2539 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2540 return;
2541 }
2542
2543 let now = Instant::now();
2544 if let Some(last) = self.last_auto_capture_at {
2545 if now.duration_since(last).as_secs() < 10 {
2546 return;
2547 }
2548 }
2549
2550 self.last_auto_capture_at = Some(now);
2551
2552 let thread_id = self.id().clone();
2553 let github_login = self
2554 .project
2555 .read(cx)
2556 .user_store()
2557 .read(cx)
2558 .current_user()
2559 .map(|user| user.github_login.clone());
2560 let client = self.project.read(cx).client().clone();
2561 let serialize_task = self.serialize(cx);
2562
2563 cx.background_executor()
2564 .spawn(async move {
2565 if let Ok(serialized_thread) = serialize_task.await {
2566 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2567 telemetry::event!(
2568 "Agent Thread Auto-Captured",
2569 thread_id = thread_id.to_string(),
2570 thread_data = thread_data,
2571 auto_capture_reason = "tracked_user",
2572 github_login = github_login
2573 );
2574
2575 client.telemetry().flush_events().await;
2576 }
2577 }
2578 })
2579 .detach();
2580 }
2581
    /// Returns the token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2585
2586 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2587 let Some(model) = self.configured_model.as_ref() else {
2588 return TotalTokenUsage::default();
2589 };
2590
2591 let max = model.model.max_token_count();
2592
2593 let index = self
2594 .messages
2595 .iter()
2596 .position(|msg| msg.id == message_id)
2597 .unwrap_or(0);
2598
2599 if index == 0 {
2600 return TotalTokenUsage { total: 0, max };
2601 }
2602
2603 let token_usage = &self
2604 .request_token_usage
2605 .get(index - 1)
2606 .cloned()
2607 .unwrap_or_default();
2608
2609 TotalTokenUsage {
2610 total: token_usage.total_tokens() as usize,
2611 max,
2612 }
2613 }
2614
2615 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2616 let model = self.configured_model.as_ref()?;
2617
2618 let max = model.model.max_token_count();
2619
2620 if let Some(exceeded_error) = &self.exceeded_window_error {
2621 if model.model.id() == exceeded_error.model_id {
2622 return Some(TotalTokenUsage {
2623 total: exceeded_error.token_count,
2624 max,
2625 });
2626 }
2627 }
2628
2629 let total = self
2630 .token_usage_at_last_message()
2631 .unwrap_or_default()
2632 .total_tokens() as usize;
2633
2634 Some(TotalTokenUsage { total, max })
2635 }
2636
2637 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2638 self.request_token_usage
2639 .get(self.messages.len().saturating_sub(1))
2640 .or_else(|| self.request_token_usage.last())
2641 .cloned()
2642 }
2643
    /// Records `token_usage` for the most recent message.
    ///
    /// The usage vector is first resized (growing or shrinking) to match the
    /// number of messages, back-filling any gap with the previously recorded
    /// usage so earlier slots stay meaningful; the final slot then receives
    /// the new value. If there are no messages, the update is dropped.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Back-fill with the last known usage (or zeroes if none).
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2653
2654 pub fn deny_tool_use(
2655 &mut self,
2656 tool_use_id: LanguageModelToolUseId,
2657 tool_name: Arc<str>,
2658 window: Option<AnyWindowHandle>,
2659 cx: &mut Context<Self>,
2660 ) {
2661 let err = Err(anyhow::anyhow!(
2662 "Permission to run tool action denied by user"
2663 ));
2664
2665 self.tool_use.insert_tool_output(
2666 tool_use_id.clone(),
2667 tool_name,
2668 err,
2669 self.configured_model.as_ref(),
2670 );
2671 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2672 }
2673}
2674
/// User-facing errors that can interrupt a thread's completion flow.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Further requests require payment (see `PaymentRequiredError`).
    #[error("Payment required")]
    PaymentRequired,
    /// The configured monthly spend cap has been reached.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// The request quota for the user's plan has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a display header and a message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2689
/// Events emitted by a [`Thread`] for observers (UI, logging, etc.).
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    /// A completion response streamed new data.
    StreamedCompletion,
    /// A chunk of text arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant text streamed for the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text streamed for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request with its input payload.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// A tool use that was referenced but never fully produced.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to deserialize.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        /// The raw JSON string that could not be parsed.
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message's content changed.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint state changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Listeners should cancel any in-progress editing
    /// (emitted by [`Thread::cancel_editing`]).
    CancelEditing,
    /// An in-flight completion and/or pending tool uses were canceled.
    CompletionCanceled,
}
2732
/// Lets observers subscribe to [`ThreadEvent`]s emitted by a [`Thread`].
impl EventEmitter<ThreadEvent> for Thread {}
2734
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifies this completion among concurrently pending ones.
    id: usize,
    // Queueing status of this completion request.
    queue_state: QueueState,
    // Keeps the completion's background work alive; gpui tasks are
    // canceled when dropped, so holding this keeps the request running.
    _task: Task<()>,
}
2740
2741#[cfg(test)]
2742mod tests {
2743 use super::*;
2744 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2745 use assistant_settings::{AssistantSettings, LanguageModelParameters};
2746 use assistant_tool::ToolRegistry;
2747 use editor::EditorSettings;
2748 use gpui::TestAppContext;
2749 use language_model::fake_provider::FakeLanguageModel;
2750 use project::{FakeFs, Project};
2751 use prompt_store::PromptBuilder;
2752 use serde_json::json;
2753 use settings::{Settings, SettingsStore};
2754 use std::sync::Arc;
2755 use theme::ThemeSettings;
2756 use util::path;
2757 use workspace::Workspace;
2758
    /// Verifies that context attached to a user message is rendered into
    /// both the stored message and the completion request in the expected
    /// `<context>` format.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request (index 0 is the system prompt)
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2834
    /// Verifies that each new user message only carries context that was not
    /// already attached to an earlier message, and that
    /// `new_context_for_thread` respects an "up to message" boundary.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a boundary at message 2, everything after message 1 counts as new.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
2990
    /// Verifies that messages inserted with an empty `ContextLoadResult`
    /// carry no context text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request (index 0 is the system prompt)
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3068
    /// Verifies that a "files changed since last read" notification is
    /// appended to the completion request only after a context buffer is
    /// modified.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3155
    /// Verifies that `model_parameters` temperature overrides match by
    /// provider, by model, by both, and not across mismatched providers.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider: the override must NOT apply
        cx.update(|cx| {
            AssistantSettings::override_global(
                AssistantSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AssistantSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, None);
    }
3249
    /// Registers the global settings and tool registry every test in this
    /// module depends on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3266
    // Helper to create a test project with test files
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        // Files are described as a JSON tree rooted at "/test".
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }
3276
    /// Builds the common test fixture: a workspace, a thread store with an
    /// empty thread, a context store, and a fake language model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        // A fake model so tests never hit a real provider.
        let model = FakeLanguageModel::default();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        (workspace, thread_store, thread, context_store, model)
    }
3311
    /// Opens `path` in the project and registers its buffer with the context
    /// store, returning the opened buffer for further manipulation.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(buffer_path.clone(), cx)
            })
            .await
            .unwrap();

        context_store.update(cx, |context_store, cx| {
            context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
        });

        Ok(buffer)
    }
3335}