1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
8use anyhow::{Result, anyhow};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{HashMap, HashSet};
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
42
43use crate::ThreadStore;
44use crate::agent_profile::AgentProfile;
45use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
46use crate::thread_store::{
47 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
48 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
49};
50use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
51
/// Unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
56
impl ThreadId {
    /// Generates a fresh, random (UUID v4) thread identifier.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}
62
63impl std::fmt::Display for ThreadId {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 write!(f, "{}", self.0)
66 }
67}
68
69impl From<&str> for ThreadId {
70 fn from(value: &str) -> Self {
71 Self(value.into())
72 }
73}
74
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
80
impl PromptId {
    /// Generates a fresh, random (UUID v4) prompt identifier.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}
86
87impl std::fmt::Display for PromptId {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 write!(f, "{}", self.0)
90 }
91}
92
/// Monotonically increasing, per-thread identifier for a [`Message`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
101
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message text.
    pub range: Range<usize>,
    /// Icon/label used to render the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
110
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (text / thinking / redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context attached to this message, already loaded into text/images.
    pub loaded_context: LoadedContext,
    /// Creases to restore when re-opening this message in an editor.
    pub creases: Vec<MessageCrease>,
    /// Hidden messages are sent to the model but not shown in the UI.
    pub is_hidden: bool,
}
121
122impl Message {
123 /// Returns whether the message contains any meaningful text that should be displayed
124 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
125 pub fn should_display_content(&self) -> bool {
126 self.segments.iter().all(|segment| segment.should_display())
127 }
128
129 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
130 if let Some(MessageSegment::Thinking {
131 text: segment,
132 signature: current_signature,
133 }) = self.segments.last_mut()
134 {
135 if let Some(signature) = signature {
136 *current_signature = Some(signature);
137 }
138 segment.push_str(text);
139 } else {
140 self.segments.push(MessageSegment::Thinking {
141 text: text.to_string(),
142 signature,
143 });
144 }
145 }
146
147 pub fn push_redacted_thinking(&mut self, data: String) {
148 self.segments.push(MessageSegment::RedactedThinking(data));
149 }
150
151 pub fn push_text(&mut self, text: &str) {
152 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
153 segment.push_str(text);
154 } else {
155 self.segments.push(MessageSegment::Text(text.to_string()));
156 }
157 }
158
159 pub fn to_string(&self) -> String {
160 let mut result = String::new();
161
162 if !self.loaded_context.text.is_empty() {
163 result.push_str(&self.loaded_context.text);
164 }
165
166 for segment in &self.segments {
167 match segment {
168 MessageSegment::Text(text) => result.push_str(text),
169 MessageSegment::Thinking { text, .. } => {
170 result.push_str("<think>\n");
171 result.push_str(text);
172 result.push_str("\n</think>");
173 }
174 MessageSegment::RedactedThinking(_) => {}
175 }
176 }
177
178 result
179 }
180}
181
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(String),
}

impl MessageSegment {
    /// Returns whether this segment carries user-visible content.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking has no renderable content and is never displayed.
    pub fn should_display(&self) -> bool {
        match self {
            // Bug fix: these previously returned `true` for *empty* text,
            // inverting the meaning of "should display" and contradicting
            // `Message::should_display_content`'s documentation.
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
201
/// Point-in-time capture of the project's worktrees and unsaved buffers,
/// recorded for telemetry/feedback alongside a thread.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}
208
/// Snapshot of a single worktree: its path plus git metadata, if any.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}
214
/// Git repository metadata captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Uncommitted diff text, when one was collected.
    pub diff: Option<String>,
}
222
/// Associates a git checkpoint with the message that triggered it, so the
/// project state can be restored to just before that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
228
/// Thumbs-up / thumbs-down feedback given by the user on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
234
235pub enum LastRestoreCheckpoint {
236 Pending {
237 message_id: MessageId,
238 },
239 Error {
240 message_id: MessageId,
241 error: String,
242 },
243}
244
245impl LastRestoreCheckpoint {
246 pub fn message_id(&self) -> MessageId {
247 match self {
248 LastRestoreCheckpoint::Pending { message_id } => *message_id,
249 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
250 }
251 }
252}
253
254#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
255pub enum DetailedSummaryState {
256 #[default]
257 NotGenerated,
258 Generating {
259 message_id: MessageId,
260 },
261 Generated {
262 text: SharedString,
263 message_id: MessageId,
264 },
265}
266
267impl DetailedSummaryState {
268 fn text(&self) -> Option<SharedString> {
269 if let Self::Generated { text, .. } = self {
270 Some(text.clone())
271 } else {
272 None
273 }
274 }
275}
276
/// Running token total for a thread together with the model's context limit.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: u64,
    /// Context-window size of the selected model (0 when unknown).
    pub max: u64,
}

impl TotalTokenUsage {
    /// Classifies current usage against the context window.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden for testing
        // via ZED_THREAD_WARNING_THRESHOLD. Fall back to the default on a
        // missing or malformed value instead of panicking (previously this
        // called `.parse().unwrap()`, crashing on a bad env var).
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a new usage with `tokens` added to the total. Saturates
    /// instead of overflowing so a bogus huge count can't panic/wrap.
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total.saturating_add(tokens),
            max: self.max,
        }
    }
}

/// Severity bucket produced by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
321
/// Position of a pending completion in the upstream request queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
328
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Last time the thread was modified.
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    /// In-flight task generating the short summary.
    pending_summary: Task<Option<()>>,
    /// In-flight task generating the detailed summary.
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    /// Messages ordered by ascending `MessageId` (relied on by binary search).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken at the latest user message, finalized once generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Timestamp of the most recent streamed chunk; used for staleness detection.
    last_received_chunk_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
370
371#[derive(Clone, Debug, PartialEq, Eq)]
372pub enum ThreadSummary {
373 Pending,
374 Generating,
375 Ready(SharedString),
376 Error,
377}
378
379impl ThreadSummary {
380 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
381
382 pub fn or_default(&self) -> SharedString {
383 self.unwrap_or(Self::DEFAULT)
384 }
385
386 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
387 self.ready().unwrap_or_else(|| message.into())
388 }
389
390 pub fn ready(&self) -> Option<SharedString> {
391 match self {
392 ThreadSummary::Ready(summary) => Some(summary.clone()),
393 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
394 }
395 }
396}
397
/// Recorded when a request overflows the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
405
406impl Thread {
    /// Creates a new, empty thread using the globally configured default
    /// model and agent profile.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the project state eagerly in the background; the shared
            // task lets multiple consumers await the same snapshot.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
462
    /// Reconstructs a thread from its serialized form, restoring messages,
    /// tool-use state, and the previously selected model/profile (falling
    /// back to the global defaults when absent).
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model persisted with the thread; fall back to the
        // registry's default when it is gone or was never recorded.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // the structured contexts/images are not persisted.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
587
    /// Installs a callback invoked with each outgoing request and the events
    /// received for it (used by eval/debugging harnesses).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// This thread's unique identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// The agent profile (tool set) currently in use.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }
603
604 pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
605 if &id != self.profile.id() {
606 self.profile = AgentProfile::new(id, self.tools.clone());
607 cx.emit(ThreadEvent::ProfileChanged);
608 }
609 }
610
    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// Last modification time.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the modification time to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Shared system-prompt context for this thread's project.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, lazily falling back to the registry's
    /// default the first time none has been set.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the configured model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Current state of the thread's short summary/title.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
650
    /// Sets the thread's summary, replacing an existing `Ready` summary or
    /// recovering from `Error`. No-op while a summary is pending/generating.
    /// An empty string falls back to the default title.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            // Don't clobber a summary that is still being produced.
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            // Treat a failed summary as the default so any non-default text replaces it.
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }
669
    /// The completion mode (e.g. normal vs. burn) for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    /// Looks up a message by id. Relies on `messages` being sorted by
    /// ascending id (ids are assigned via `post_inc` on insertion).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before the stream counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        // NOTE(review): the doc above holds only if `last_received_chunk_at`
        // is cleared when generation finishes — confirm against the rest of
        // the file; this method itself only checks elapsed time.
        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue position of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
717
    /// The pending tool use with the given id, if it exists.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool use is still pending (including errored ones).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
739
    /// Restores the project to the given checkpoint's git state and, on
    /// success, truncates the thread back to the checkpoint's message.
    /// Progress/failure is surfaced via `last_restore_checkpoint` and
    /// `CheckpointChanged` events.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight before kicking off the async work so
        // the UI can show progress immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error visible; do not truncate on failure.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way the pending checkpoint no longer applies.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
775
776 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
777 let pending_checkpoint = if self.is_generating() {
778 return;
779 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
780 checkpoint
781 } else {
782 return;
783 };
784
785 self.finalize_checkpoint(pending_checkpoint, cx);
786 }
787
    /// Compares the pending checkpoint against the project's current state
    /// and keeps it only if the two differ (i.e. the turn actually changed
    /// something). On comparison failure the checkpoint is kept defensively.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // Treat a failed comparison as "not equal" so the
                    // checkpoint is retained rather than silently dropped.
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Could not take a final checkpoint at all: keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
822
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint-restore attempt, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
833
834 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
835 let Some(message_ix) = self
836 .messages
837 .iter()
838 .rposition(|message| message.id == message_id)
839 else {
840 return;
841 };
842 for deleted_message in self.messages.drain(message_ix..) {
843 self.checkpoints_by_message.remove(&deleted_message.id);
844 }
845 cx.notify();
846 }
847
    /// Iterates over the contexts attached to the given message (empty when
    /// the message does not exist or has no contexts).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
855
856 pub fn is_turn_end(&self, ix: usize) -> bool {
857 if self.messages.is_empty() {
858 return false;
859 }
860
861 if !self.is_generating() && ix == self.messages.len() - 1 {
862 return true;
863 }
864
865 let Some(message) = self.messages.get(ix) else {
866 return false;
867 };
868
869 if message.role != Role::Assistant {
870 return false;
871 }
872
873 self.messages
874 .get(ix + 1)
875 .and_then(|message| {
876 self.message(message.id)
877 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
878 })
879 .unwrap_or(false)
880 }
881
    /// Usage reported by the most recent completion request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the server reported that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            // Errored tool uses will never run, so they can't edit anything.
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }
908
    /// Tool uses issued by the given (assistant) message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced in response to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Text output of a completed tool use; `None` for image results.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// UI card associated with a tool result, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
937
    /// Return tools that are both enabled and supported by the model
    pub fn available_tools(
        &self,
        cx: &App,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<LanguageModelRequestTool> {
        if model.supports_tools() {
            // Disambiguate duplicate tool names before handing them to the model.
            resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
                .into_iter()
                .filter_map(|(name, tool)| {
                    // Skip tools that cannot be supported
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name,
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect()
        } else {
            // Models without tool support get no tool definitions at all.
            Vec::default()
        }
    }
961
    /// Inserts a visible user message, registering any context buffers as
    /// read in the action log and recording an optional git checkpoint to be
    /// finalized when the turn completes.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Track context buffers as read so later edits can be diffed.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
998
    /// Inserts a hidden "Continue where you left off" user message, used to
    /// resume generation (e.g. after hitting the tool-use limit). Clears any
    /// pending checkpoint since this is not a real user turn.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }

    /// Inserts an assistant message with the given segments.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1027
    /// Appends a message with a freshly assigned id, bumps the thread's
    /// modification time, and emits `MessageAdded`. Because ids increase
    /// monotonically, `messages` stays sorted for `Self::message`'s binary search.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1050
    /// Replaces an existing message's role/segments/creases (and optionally
    /// its loaded context), records a checkpoint if provided, and emits
    /// `MessageEdited`. Returns `false` when the message does not exist.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        // `None` leaves the previously loaded context untouched.
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1083
1084 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1085 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1086 return false;
1087 };
1088 self.messages.remove(index);
1089 self.touch_updated_at();
1090 cx.emit(ThreadEvent::MessageDeleted(id));
1091 true
1092 }
1093
1094 /// Returns the representation of this [`Thread`] in a textual form.
1095 ///
1096 /// This is the representation we use when attaching a thread as context to another thread.
1097 pub fn text(&self) -> String {
1098 let mut text = String::new();
1099
1100 for message in &self.messages {
1101 text.push_str(match message.role {
1102 language_model::Role::User => "User:",
1103 language_model::Role::Assistant => "Agent:",
1104 language_model::Role::System => "System:",
1105 });
1106 text.push('\n');
1107
1108 for segment in &message.segments {
1109 match segment {
1110 MessageSegment::Text(content) => text.push_str(content),
1111 MessageSegment::Thinking { text: content, .. } => {
1112 text.push_str(&format!("<think>{}</think>", content))
1113 }
1114 MessageSegment::RedactedThinking(_) => {}
1115 }
1116 }
1117 text.push('\n');
1118 }
1119
1120 text
1121 }
1122
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// The initial project snapshot is awaited asynchronously; all remaining
    /// state is read from the entity synchronously once it resolves.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                // Convert each in-memory message — segments, tool uses/results,
                // attached context, and creases — into its serialized form.
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                // Persist provider/model ids so the thread restores with the
                // same configuration.
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1208
    /// Returns how many more agentic turns this thread may take before stopping.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1212
    /// Sets the number of agentic turns this thread may take before stopping.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1216
1217 pub fn send_to_model(
1218 &mut self,
1219 model: Arc<dyn LanguageModel>,
1220 intent: CompletionIntent,
1221 window: Option<AnyWindowHandle>,
1222 cx: &mut Context<Self>,
1223 ) {
1224 if self.remaining_turns == 0 {
1225 return;
1226 }
1227
1228 self.remaining_turns -= 1;
1229
1230 let request = self.to_completion_request(model.clone(), intent, cx);
1231
1232 self.stream_completion(request, model, window, cx);
1233 }
1234
1235 pub fn used_tools_since_last_user_message(&self) -> bool {
1236 for message in self.messages.iter().rev() {
1237 if self.tool_use.message_has_tool_results(message.id) {
1238 return true;
1239 } else if message.role == Role::User {
1240 return false;
1241 }
1242 }
1243
1244 false
1245 }
1246
    /// Builds the [`LanguageModelRequest`] for this thread: system prompt,
    /// conversation history, completed tool results, stale-file notices, and
    /// the tools available to the model.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt is rendered from the project context. Failures are
        // surfaced to the UI rather than aborting the request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message eligible for prompt caching; the
        // cache flag is applied once, after the loop.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            // Empty text/thinking segments are dropped; redacted thinking is
            // always forwarded verbatim.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses are echoed as a tool-use on this message plus
            // a tool-result in a synthetic User message; a still-pending tool
            // use makes this message ineligible for caching.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        self.attach_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // The thread's completion mode is only honored when the model supports
        // max mode; otherwise fall back to Normal.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1406
1407 fn to_summarize_request(
1408 &self,
1409 model: &Arc<dyn LanguageModel>,
1410 intent: CompletionIntent,
1411 added_user_message: String,
1412 cx: &App,
1413 ) -> LanguageModelRequest {
1414 let mut request = LanguageModelRequest {
1415 thread_id: None,
1416 prompt_id: None,
1417 intent: Some(intent),
1418 mode: None,
1419 messages: vec![],
1420 tools: Vec::new(),
1421 tool_choice: None,
1422 stop: Vec::new(),
1423 temperature: AgentSettings::temperature_for_model(model, cx),
1424 };
1425
1426 for message in &self.messages {
1427 let mut request_message = LanguageModelRequestMessage {
1428 role: message.role,
1429 content: Vec::new(),
1430 cache: false,
1431 };
1432
1433 for segment in &message.segments {
1434 match segment {
1435 MessageSegment::Text(text) => request_message
1436 .content
1437 .push(MessageContent::Text(text.clone())),
1438 MessageSegment::Thinking { .. } => {}
1439 MessageSegment::RedactedThinking(_) => {}
1440 }
1441 }
1442
1443 if request_message.content.is_empty() {
1444 continue;
1445 }
1446
1447 request.messages.push(request_message);
1448 }
1449
1450 request.messages.push(LanguageModelRequestMessage {
1451 role: Role::User,
1452 content: vec![MessageContent::Text(added_user_message)],
1453 cache: false,
1454 });
1455
1456 request
1457 }
1458
1459 fn attach_tracked_files_state(
1460 &self,
1461 messages: &mut Vec<LanguageModelRequestMessage>,
1462 cx: &App,
1463 ) {
1464 let mut stale_files = String::new();
1465
1466 let action_log = self.action_log.read(cx);
1467
1468 for stale_file in action_log.stale_buffers(cx) {
1469 if let Some(file) = stale_file.read(cx).file() {
1470 writeln!(&mut stale_files, "- {}", file.path().display()).ok();
1471 }
1472 }
1473
1474 if stale_files.is_empty() {
1475 return;
1476 }
1477
1478 // NOTE: Changes to this prompt require a symmetric update in the LLM Worker
1479 const STALE_FILES_HEADER: &str = include_str!("./prompts/stale_files_prompt_header.txt");
1480 let content = MessageContent::Text(
1481 format!("{STALE_FILES_HEADER}{stale_files}").replace("\r\n", "\n"),
1482 );
1483
1484 // Insert our message before the last Assistant message.
1485 // Inserting it to the tail distracts the agent too much
1486 let insert_position = messages
1487 .iter()
1488 .enumerate()
1489 .rfind(|(_, message)| message.role == Role::Assistant)
1490 .map_or(messages.len(), |(i, _)| i);
1491
1492 let request_message = LanguageModelRequestMessage {
1493 role: Role::User,
1494 content: vec![content],
1495 cache: false,
1496 };
1497
1498 messages.insert(insert_position, request_message);
1499
1500 // It makes no sense to cache messages after this one because
1501 // the cache is invalidated when this message is gone.
1502 // Move the cache marker before this message.
1503 let has_cached_messages_after = messages
1504 .iter()
1505 .skip(insert_position + 1)
1506 .any(|message| message.cache);
1507
1508 if has_cached_messages_after {
1509 messages[insert_position - 1].cache = true;
1510 }
1511 }
1512
    /// Streams a completion from `model` for `request`, applying each event to
    /// the thread as it arrives (text/thinking chunks, tool uses, usage and
    /// queue-status updates) and handling the terminal stop reason or error
    /// once the stream ends. The spawned task is tracked in
    /// `pending_completions` so it can be cancelled.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only capture request/response data when a request callback is
        // installed (used for debugging/inspection).
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // Id of the assistant message this response streams into,
                // created lazily on the first event that needs one.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        // Invalid tool-input JSON is recoverable (reported back
                        // to the model as an errored tool result); any other
                        // error aborts the stream.
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                            Err(err @ LanguageModelCompletionError::RateLimit(..)) => {
                                return Err(err.into());
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Usage updates are cumulative per request, so
                                // only the delta since the previous update is
                                // added to the thread total.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::RedactedThinking {
                                data
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_redacted_thinking(data);
                                    } else {
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::RedactedThinking(data)],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Partial (still-streaming) input is forwarded
                                // to the UI below; complete input is not.
                                // NOTE(review): `(&x).clone()` is equivalent to
                                // `x.clone()` here — the extra borrow is redundant.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so other tasks get a chance to run.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    // Terminal handling: on a clean stop, dispatch by reason;
                    // on error, surface it to the UI and cancel the completion.
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                            StopReason::Refusal => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });

                                // Remove the turn that was refused.
                                //
                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                {
                                    let mut messages_to_remove = Vec::new();

                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
                                        messages_to_remove.push(message.id);

                                        if message.role == Role::User {
                                            if ix == 0 {
                                                break;
                                            }

                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                if prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }
                                    }

                                    for message_id in messages_to_remove {
                                        thread.delete_message(message_id, cx);
                                    }
                                }

                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Language model refusal".into(),
                                    message: "Model refused to generate content for safety reasons.".into(),
                                }));
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to their dedicated UI
                            // treatments; fall back to a generic error message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }

                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report only the tokens consumed by this completion.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1916
1917 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1918 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1919 println!("No thread summary model");
1920 return;
1921 };
1922
1923 if !model.provider.is_authenticated(cx) {
1924 return;
1925 }
1926
1927 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
1928
1929 let request = self.to_summarize_request(
1930 &model.model,
1931 CompletionIntent::ThreadSummarization,
1932 added_user_message.into(),
1933 cx,
1934 );
1935
1936 self.summary = ThreadSummary::Generating;
1937
1938 self.pending_summary = cx.spawn(async move |this, cx| {
1939 let result = async {
1940 let mut messages = model.model.stream_completion(request, &cx).await?;
1941
1942 let mut new_summary = String::new();
1943 while let Some(event) = messages.next().await {
1944 let Ok(event) = event else {
1945 continue;
1946 };
1947 let text = match event {
1948 LanguageModelCompletionEvent::Text(text) => text,
1949 LanguageModelCompletionEvent::StatusUpdate(
1950 CompletionRequestStatus::UsageUpdated { amount, limit },
1951 ) => {
1952 this.update(cx, |thread, _cx| {
1953 thread.last_usage = Some(RequestUsage {
1954 limit,
1955 amount: amount as i32,
1956 });
1957 })?;
1958 continue;
1959 }
1960 _ => continue,
1961 };
1962
1963 let mut lines = text.lines();
1964 new_summary.extend(lines.next());
1965
1966 // Stop if the LLM generated multiple lines.
1967 if lines.next().is_some() {
1968 break;
1969 }
1970 }
1971
1972 anyhow::Ok(new_summary)
1973 }
1974 .await;
1975
1976 this.update(cx, |this, cx| {
1977 match result {
1978 Ok(new_summary) => {
1979 if new_summary.is_empty() {
1980 this.summary = ThreadSummary::Error;
1981 } else {
1982 this.summary = ThreadSummary::Ready(new_summary.into());
1983 }
1984 }
1985 Err(err) => {
1986 this.summary = ThreadSummary::Error;
1987 log::error!("Failed to generate thread summary: {}", err);
1988 }
1989 }
1990 cx.emit(ThreadEvent::SummaryGenerated);
1991 })
1992 .log_err()?;
1993
1994 Some(())
1995 });
1996 }
1997
    /// Starts generating a detailed summary of this thread when the existing
    /// one is missing or stale (i.e. does not cover the latest message). The
    /// finished summary is saved back to the thread store so it can be reused.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        // A summary (in-flight or finished) keyed to the latest message is
        // already current; nothing to do.
        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // On stream failure, roll the state back so a later call retries.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2087
2088 pub async fn wait_for_detailed_summary_or_text(
2089 this: &Entity<Self>,
2090 cx: &mut AsyncApp,
2091 ) -> Option<SharedString> {
2092 let mut detailed_summary_rx = this
2093 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2094 .ok()?;
2095 loop {
2096 match detailed_summary_rx.recv().await? {
2097 DetailedSummaryState::Generating { .. } => {}
2098 DetailedSummaryState::NotGenerated => {
2099 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2100 }
2101 DetailedSummaryState::Generated { text, .. } => return Some(text),
2102 }
2103 }
2104 }
2105
2106 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2107 self.detailed_summary_rx
2108 .borrow()
2109 .text()
2110 .unwrap_or_else(|| self.text().into())
2111 }
2112
    /// Whether a detailed summary is currently being generated for this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2119
    /// Dispatches all idle pending tool uses.
    ///
    /// Tools requiring confirmation (unless "always allow" is enabled) emit
    /// [`ThreadEvent::ToolConfirmationNeeded`] instead of running; tool names
    /// the model hallucinated are reported back to it as errored results.
    /// Returns the pending tool uses that were considered.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // One completion request is built up front and shared by every tool
        // run below.
        let request =
            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AgentSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool we don't have.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
2174
2175 pub fn handle_hallucinated_tool_use(
2176 &mut self,
2177 tool_use_id: LanguageModelToolUseId,
2178 hallucinated_tool_name: Arc<str>,
2179 window: Option<AnyWindowHandle>,
2180 cx: &mut Context<Thread>,
2181 ) {
2182 let available_tools = self.profile.enabled_tools(cx);
2183
2184 let tool_list = available_tools
2185 .iter()
2186 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2187 .collect::<Vec<_>>()
2188 .join("\n");
2189
2190 let error_message = format!(
2191 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2192 hallucinated_tool_name, tool_list
2193 );
2194
2195 let pending_tool_use = self.tool_use.insert_tool_output(
2196 tool_use_id.clone(),
2197 hallucinated_tool_name,
2198 Err(anyhow!("Missing tool call: {error_message}")),
2199 self.configured_model.as_ref(),
2200 );
2201
2202 cx.emit(ThreadEvent::MissingToolUse {
2203 tool_use_id: tool_use_id.clone(),
2204 ui_text: error_message.into(),
2205 });
2206
2207 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2208 }
2209
    /// Records a tool call whose input JSON failed to parse, reporting the
    /// parse error back to the model as an errored tool result and notifying
    /// the UI via [`ThreadEvent::InvalidToolInput`].
    pub fn receive_invalid_tool_json(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        invalid_json: Arc<str>,
        error: String,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        log::error!("The model returned invalid input JSON: {invalid_json}");

        let pending_tool_use = self.tool_use.insert_tool_output(
            tool_use_id.clone(),
            tool_name,
            Err(anyhow!("Error parsing input JSON: {error}")),
            self.configured_model.as_ref(),
        );
        // Prefer the UI label of the pending tool use; a missing entry
        // indicates an internal inconsistency, so fall back to a generic label.
        let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
            pending_tool_use.ui_text.clone()
        } else {
            log::error!(
                "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
            );
            format!("Unknown tool {}", tool_use_id).into()
        };

        cx.emit(ThreadEvent::InvalidToolInput {
            tool_use_id: tool_use_id.clone(),
            ui_text,
            invalid_input_json: invalid_json,
        });

        self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
    }
2244
    /// Starts executing a tool call: spawns the tool's async work and records
    /// it as pending (under `ui_text`) so progress can be shown until the
    /// task resolves.
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task =
            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
2261
    /// Runs `tool` with the given `input` and returns a task that records the
    /// tool's output on the thread once it completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Kick off the tool synchronously; only awaiting its output happens
        // in the spawned task below.
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool was running;
                // in that case there is nowhere to record the output.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2308
    /// Called whenever a tool use completes (successfully, with an error, or
    /// canceled). When this was the last outstanding tool use and the run was
    /// not canceled, the collected results are sent back to the model.
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
                if !canceled {
                    // All tool results are in: continue the conversation by
                    // sending them back to the model.
                    self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
                }
                self.auto_capture_telemetry(cx);
            }
        }

        // Emitted in every case, including cancellation, so listeners can
        // clean up UI state for this tool use.
        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }
2331
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        let mut canceled = self.pending_completions.pop().is_some();

        // Any tool uses still running are canceled too; each is reported as
        // finished with `canceled == true` so results aren't sent to the model.
        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);

            // When canceled, we always want to insert the checkpoint.
            // (We skip over finalize_pending_checkpoint, because it
            // would conclude we didn't have anything to insert here.)
            if let Some(checkpoint) = self.pending_checkpoint.take() {
                self.insert_checkpoint(checkpoint, cx);
            }
        } else {
            self.finalize_pending_checkpoint(cx);
        }

        canceled
    }
2368
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits
    /// [`ThreadEvent::CancelEditing`]; no state on the thread itself changes.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2376
    /// Thread-level feedback, set when a rating was reported before any
    /// assistant message existed to attach it to.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2380
    /// The feedback (if any) the user has reported for a specific message.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2384
    /// Records the user's rating for a single message and reports it (along
    /// with a serialized copy of the thread and a project snapshot) via
    /// telemetry. Returns the background task performing the report.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Re-submitting the same rating is a no-op.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failures degrade to a null payload rather than
            // dropping the rating event entirely.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2441
    /// Records thread-level feedback, preferring to attach it to the most
    /// recent assistant message when one exists.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet: store the rating on the thread itself
            // and report it without message-level details.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Degrade serialization failures to a null payload.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2487
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Snapshot every visible worktree concurrently.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Record paths of buffers with unsaved changes; if the app is
            // already gone the update fails and we simply skip this part.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2525
    /// Captures the absolute path and git state (branch, HEAD, remote, diff)
    /// of a single worktree for inclusion in a [`ProjectSnapshot`].
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app has shut down, return an empty snapshot rather than
            // failing the whole project snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose root contains this worktree, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend we can query;
                            // for others we can still report the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<job>>; any failure along the way
            // degrades to "no git state".
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2603
2604 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2605 let mut markdown = Vec::new();
2606
2607 let summary = self.summary().or_default();
2608 writeln!(markdown, "# {summary}\n")?;
2609
2610 for message in self.messages() {
2611 writeln!(
2612 markdown,
2613 "## {role}\n",
2614 role = match message.role {
2615 Role::User => "User",
2616 Role::Assistant => "Agent",
2617 Role::System => "System",
2618 }
2619 )?;
2620
2621 if !message.loaded_context.text.is_empty() {
2622 writeln!(markdown, "{}", message.loaded_context.text)?;
2623 }
2624
2625 if !message.loaded_context.images.is_empty() {
2626 writeln!(
2627 markdown,
2628 "\n{} images attached as context.\n",
2629 message.loaded_context.images.len()
2630 )?;
2631 }
2632
2633 for segment in &message.segments {
2634 match segment {
2635 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2636 MessageSegment::Thinking { text, .. } => {
2637 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2638 }
2639 MessageSegment::RedactedThinking(_) => {}
2640 }
2641 }
2642
2643 for tool_use in self.tool_uses_for_message(message.id, cx) {
2644 writeln!(
2645 markdown,
2646 "**Use Tool: {} ({})**",
2647 tool_use.name, tool_use.id
2648 )?;
2649 writeln!(markdown, "```json")?;
2650 writeln!(
2651 markdown,
2652 "{}",
2653 serde_json::to_string_pretty(&tool_use.input)?
2654 )?;
2655 writeln!(markdown, "```")?;
2656 }
2657
2658 for tool_result in self.tool_results_for_message(message.id) {
2659 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2660 if tool_result.is_error {
2661 write!(markdown, " (Error)")?;
2662 }
2663
2664 writeln!(markdown, "**\n")?;
2665 match &tool_result.content {
2666 LanguageModelToolResultContent::Text(text) => {
2667 writeln!(markdown, "{text}")?;
2668 }
2669 LanguageModelToolResultContent::Image(image) => {
2670 writeln!(markdown, "", image.source)?;
2671 }
2672 }
2673
2674 if let Some(output) = tool_result.output.as_ref() {
2675 writeln!(
2676 markdown,
2677 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2678 serde_json::to_string_pretty(output)?
2679 )?;
2680 }
2681 }
2682 }
2683
2684 Ok(String::from_utf8_lossy(&markdown).to_string())
2685 }
2686
    /// Marks the agent's edits intersecting `buffer_range` as kept (accepted
    /// by the user), delegating to the [`ActionLog`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2697
    /// Marks every edit the agent has made in this thread as kept (accepted).
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2702
    /// Reverts the agent's edits intersecting `buffer_ranges`, delegating to
    /// the [`ActionLog`]; returns the task performing the revert.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2713
    /// The log of file edits the agent has performed in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2717
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2721
2722 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2723 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2724 return;
2725 }
2726
2727 let now = Instant::now();
2728 if let Some(last) = self.last_auto_capture_at {
2729 if now.duration_since(last).as_secs() < 10 {
2730 return;
2731 }
2732 }
2733
2734 self.last_auto_capture_at = Some(now);
2735
2736 let thread_id = self.id().clone();
2737 let github_login = self
2738 .project
2739 .read(cx)
2740 .user_store()
2741 .read(cx)
2742 .current_user()
2743 .map(|user| user.github_login.clone());
2744 let client = self.project.read(cx).client();
2745 let serialize_task = self.serialize(cx);
2746
2747 cx.background_executor()
2748 .spawn(async move {
2749 if let Ok(serialized_thread) = serialize_task.await {
2750 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2751 telemetry::event!(
2752 "Agent Thread Auto-Captured",
2753 thread_id = thread_id.to_string(),
2754 thread_data = thread_data,
2755 auto_capture_reason = "tracked_user",
2756 github_login = github_login
2757 );
2758
2759 client.telemetry().flush_events().await;
2760 }
2761 }
2762 })
2763 .detach();
2764 }
2765
    /// Token usage accumulated across every completion made in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2769
2770 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2771 let Some(model) = self.configured_model.as_ref() else {
2772 return TotalTokenUsage::default();
2773 };
2774
2775 let max = model.model.max_token_count();
2776
2777 let index = self
2778 .messages
2779 .iter()
2780 .position(|msg| msg.id == message_id)
2781 .unwrap_or(0);
2782
2783 if index == 0 {
2784 return TotalTokenUsage { total: 0, max };
2785 }
2786
2787 let token_usage = &self
2788 .request_token_usage
2789 .get(index - 1)
2790 .cloned()
2791 .unwrap_or_default();
2792
2793 TotalTokenUsage {
2794 total: token_usage.total_tokens(),
2795 max,
2796 }
2797 }
2798
2799 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2800 let model = self.configured_model.as_ref()?;
2801
2802 let max = model.model.max_token_count();
2803
2804 if let Some(exceeded_error) = &self.exceeded_window_error {
2805 if model.model.id() == exceeded_error.model_id {
2806 return Some(TotalTokenUsage {
2807 total: exceeded_error.token_count,
2808 max,
2809 });
2810 }
2811 }
2812
2813 let total = self
2814 .token_usage_at_last_message()
2815 .unwrap_or_default()
2816 .total_tokens();
2817
2818 Some(TotalTokenUsage { total, max })
2819 }
2820
    /// Token usage recorded for the most recent message, if any.
    ///
    /// Looks up the entry aligned with the last message index, falling back
    /// to the final recorded entry when the usage vector is shorter than the
    /// message list.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2827
    /// Records `token_usage` for the last message, first growing the usage
    /// vector (padded with the previously recorded value) so it stays aligned
    /// with `messages`.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2837
2838 pub fn deny_tool_use(
2839 &mut self,
2840 tool_use_id: LanguageModelToolUseId,
2841 tool_name: Arc<str>,
2842 window: Option<AnyWindowHandle>,
2843 cx: &mut Context<Self>,
2844 ) {
2845 let err = Err(anyhow::anyhow!(
2846 "Permission to run tool action denied by user"
2847 ));
2848
2849 self.tool_use.insert_tool_output(
2850 tool_use_id.clone(),
2851 tool_name,
2852 err,
2853 self.configured_model.as_ref(),
2854 );
2855 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2856 }
2857}
2858
/// User-facing errors that can occur while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's subscription requires payment before continuing.
    #[error("Payment required")]
    PaymentRequired,
    /// The user hit the model request limit for their current plan.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a short header and a detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2871
/// Events emitted by a [`Thread`] as completions stream in and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// A problem occurred that should be surfaced to the user.
    ShowError(ThreadError),
    /// A completion finished streaming.
    StreamedCompletion,
    /// A chunk of text arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant response text for the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model referenced a tool that doesn't exist or isn't enabled.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, either normally or with an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    SummaryChanged,
    /// Tool uses are ready and awaiting execution.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool use completed (successfully, with an error, or canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool needs explicit user confirmation before it can run.
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    /// Any in-progress message editing should be canceled.
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
}
2916
// Allow `Thread` entities to emit `ThreadEvent`s to GPUI subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2918
/// A completion request that is currently in flight; dropping `_task`
/// cancels the underlying work.
struct PendingCompletion {
    id: usize,
    queue_state: QueueState,
    _task: Task<()>,
}
2924
2925/// Resolves tool name conflicts by ensuring all tool names are unique.
2926///
2927/// When multiple tools have the same name, this function applies the following rules:
2928/// 1. Native tools always keep their original name
2929/// 2. Context server tools get prefixed with their server ID and an underscore
2930/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
2931/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
2932///
2933/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
2934fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
2935 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
2936 let mut tool_name = tool.name();
2937 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
2938 tool_name
2939 }
2940
2941 const MAX_TOOL_NAME_LENGTH: usize = 64;
2942
2943 let mut duplicated_tool_names = HashSet::default();
2944 let mut seen_tool_names = HashSet::default();
2945 for tool in tools {
2946 let tool_name = resolve_tool_name(tool);
2947 if seen_tool_names.contains(&tool_name) {
2948 debug_assert!(
2949 tool.source() != assistant_tool::ToolSource::Native,
2950 "There are two built-in tools with the same name: {}",
2951 tool_name
2952 );
2953 duplicated_tool_names.insert(tool_name);
2954 } else {
2955 seen_tool_names.insert(tool_name);
2956 }
2957 }
2958
2959 if duplicated_tool_names.is_empty() {
2960 return tools
2961 .into_iter()
2962 .map(|tool| (resolve_tool_name(tool), tool.clone()))
2963 .collect();
2964 }
2965
2966 tools
2967 .into_iter()
2968 .filter_map(|tool| {
2969 let mut tool_name = resolve_tool_name(tool);
2970 if !duplicated_tool_names.contains(&tool_name) {
2971 return Some((tool_name, tool.clone()));
2972 }
2973 match tool.source() {
2974 assistant_tool::ToolSource::Native => {
2975 // Built-in tools always keep their original name
2976 Some((tool_name, tool.clone()))
2977 }
2978 assistant_tool::ToolSource::ContextServer { id } => {
2979 // Context server tools are prefixed with the context server ID, and truncated if necessary
2980 tool_name.insert(0, '_');
2981 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
2982 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
2983 let mut id = id.to_string();
2984 id.truncate(len);
2985 tool_name.insert_str(0, &id);
2986 } else {
2987 tool_name.insert_str(0, &id);
2988 }
2989
2990 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
2991
2992 if seen_tool_names.contains(&tool_name) {
2993 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
2994 None
2995 } else {
2996 Some((tool_name, tool.clone()))
2997 }
2998 }
2999 }
3000 })
3001 .collect()
3002}
3003
3004#[cfg(test)]
3005mod tests {
3006 use super::*;
3007 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
3008 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3009 use assistant_tool::ToolRegistry;
3010 use editor::EditorSettings;
3011 use gpui::TestAppContext;
3012 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3013 use project::{FakeFs, Project};
3014 use prompt_store::PromptBuilder;
3015 use serde_json::json;
3016 use settings::{Settings, SettingsStore};
3017 use std::sync::Arc;
3018 use theme::ThemeSettings;
3019 use ui::IconName;
3020 use util::path;
3021 use workspace::Workspace;
3022
    /// Verifies that a user message with attached file context stores the
    /// rendered context on the message and prepends it in the completion
    /// request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Attach the file as context before inserting the message.
        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
 println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Index 0 is the system message; the user's message is at index 1.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3099
    /// Verifies that each message only picks up context that was not already
    /// attached to an earlier message, and that `new_context_for_thread`
    /// honors an exclusion cutoff message id.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, contexts attached from message 2 onward
        // (plus the newly added file4) count as "new" again.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3255
    /// Verifies that messages inserted without any context carry an empty
    /// context string and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Index 0 is the system message; the user's message is at index 1.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3333
3334 #[gpui::test]
3335 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
3336 init_test_settings(cx);
3337
3338 let project = create_test_project(
3339 cx,
3340 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3341 )
3342 .await;
3343
3344 let (_workspace, _thread_store, thread, context_store, model) =
3345 setup_test_environment(cx, project.clone()).await;
3346
3347 // Open buffer and add it to context
3348 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
3349 .await
3350 .unwrap();
3351
3352 let context =
3353 context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
3354 let loaded_context = cx
3355 .update(|cx| load_context(vec![context], &project, &None, cx))
3356 .await;
3357
3358 // Insert user message with the buffer as context
3359 thread.update(cx, |thread, cx| {
3360 thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
3361 });
3362
3363 // Create a request and check that it doesn't have a stale buffer warning yet
3364 let initial_request = thread.update(cx, |thread, cx| {
3365 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3366 });
3367
3368 // Make sure we don't have a stale file warning yet
3369 let has_stale_warning = initial_request.messages.iter().any(|msg| {
3370 msg.string_contents()
3371 .contains("These files changed since last read:")
3372 });
3373 assert!(
3374 !has_stale_warning,
3375 "Should not have stale buffer warning before buffer is modified"
3376 );
3377
3378 // Modify the buffer
3379 buffer.update(cx, |buffer, cx| {
3380 // Find a position at the end of line 1
3381 buffer.edit(
3382 [(1..1, "\n println!(\"Added a new line\");\n")],
3383 None,
3384 cx,
3385 );
3386 });
3387
3388 // Insert another user message without context
3389 thread.update(cx, |thread, cx| {
3390 thread.insert_user_message(
3391 "What does the code do now?",
3392 ContextLoadResult::default(),
3393 None,
3394 Vec::new(),
3395 cx,
3396 )
3397 });
3398
3399 // Create a new request and check for the stale buffer warning
3400 let new_request = thread.update(cx, |thread, cx| {
3401 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3402 });
3403
3404 // We should have a stale file warning as the last message
3405 let last_message = new_request
3406 .messages
3407 .last()
3408 .expect("Request should have messages");
3409
3410 // The last message should be the stale buffer notification
3411 assert_eq!(last_message.role, Role::User);
3412
3413 // Check the exact content of the message
3414 let expected_content = "[The following is an auto-generated notification; do not reply]
3415
3416These files have changed since the last read:
3417- code.rs
3418";
3419 assert_eq!(
3420 last_message.string_contents(),
3421 expected_content,
3422 "Last message should be exactly the stale buffer notification"
3423 );
3424
3425 // The message before the notification should be cached
3426 let index = new_request.messages.len() - 2;
3427 let previous_message = new_request.messages.get(index).unwrap();
3428 assert!(
3429 previous_message.cache,
3430 "Message before the stale buffer notification should be cached"
3431 );
3432 }
3433
3434 #[gpui::test]
3435 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3436 init_test_settings(cx);
3437
3438 let project = create_test_project(
3439 cx,
3440 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3441 )
3442 .await;
3443
3444 let (_workspace, thread_store, thread, _context_store, _model) =
3445 setup_test_environment(cx, project.clone()).await;
3446
3447 // Check that we are starting with the default profile
3448 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3449 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3450 assert_eq!(
3451 profile,
3452 AgentProfile::new(AgentProfileId::default(), tool_set)
3453 );
3454 }
3455
3456 #[gpui::test]
3457 async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
3458 init_test_settings(cx);
3459
3460 let project = create_test_project(
3461 cx,
3462 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3463 )
3464 .await;
3465
3466 let (_workspace, thread_store, thread, _context_store, _model) =
3467 setup_test_environment(cx, project.clone()).await;
3468
3469 // Profile gets serialized with default values
3470 let serialized = thread
3471 .update(cx, |thread, cx| thread.serialize(cx))
3472 .await
3473 .unwrap();
3474
3475 assert_eq!(serialized.profile, Some(AgentProfileId::default()));
3476
3477 let deserialized = cx.update(|cx| {
3478 thread.update(cx, |thread, cx| {
3479 Thread::deserialize(
3480 thread.id.clone(),
3481 serialized,
3482 thread.project.clone(),
3483 thread.tools.clone(),
3484 thread.prompt_builder.clone(),
3485 thread.project_context.clone(),
3486 None,
3487 cx,
3488 )
3489 })
3490 });
3491 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3492
3493 assert_eq!(
3494 deserialized.profile,
3495 AgentProfile::new(AgentProfileId::default(), tool_set)
3496 );
3497 }
3498
3499 #[gpui::test]
3500 async fn test_temperature_setting(cx: &mut TestAppContext) {
3501 init_test_settings(cx);
3502
3503 let project = create_test_project(
3504 cx,
3505 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3506 )
3507 .await;
3508
3509 let (_workspace, _thread_store, thread, _context_store, model) =
3510 setup_test_environment(cx, project.clone()).await;
3511
3512 // Both model and provider
3513 cx.update(|cx| {
3514 AgentSettings::override_global(
3515 AgentSettings {
3516 model_parameters: vec![LanguageModelParameters {
3517 provider: Some(model.provider_id().0.to_string().into()),
3518 model: Some(model.id().0.clone()),
3519 temperature: Some(0.66),
3520 }],
3521 ..AgentSettings::get_global(cx).clone()
3522 },
3523 cx,
3524 );
3525 });
3526
3527 let request = thread.update(cx, |thread, cx| {
3528 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3529 });
3530 assert_eq!(request.temperature, Some(0.66));
3531
3532 // Only model
3533 cx.update(|cx| {
3534 AgentSettings::override_global(
3535 AgentSettings {
3536 model_parameters: vec![LanguageModelParameters {
3537 provider: None,
3538 model: Some(model.id().0.clone()),
3539 temperature: Some(0.66),
3540 }],
3541 ..AgentSettings::get_global(cx).clone()
3542 },
3543 cx,
3544 );
3545 });
3546
3547 let request = thread.update(cx, |thread, cx| {
3548 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3549 });
3550 assert_eq!(request.temperature, Some(0.66));
3551
3552 // Only provider
3553 cx.update(|cx| {
3554 AgentSettings::override_global(
3555 AgentSettings {
3556 model_parameters: vec![LanguageModelParameters {
3557 provider: Some(model.provider_id().0.to_string().into()),
3558 model: None,
3559 temperature: Some(0.66),
3560 }],
3561 ..AgentSettings::get_global(cx).clone()
3562 },
3563 cx,
3564 );
3565 });
3566
3567 let request = thread.update(cx, |thread, cx| {
3568 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3569 });
3570 assert_eq!(request.temperature, Some(0.66));
3571
3572 // Same model name, different provider
3573 cx.update(|cx| {
3574 AgentSettings::override_global(
3575 AgentSettings {
3576 model_parameters: vec![LanguageModelParameters {
3577 provider: Some("anthropic".into()),
3578 model: Some(model.id().0.clone()),
3579 temperature: Some(0.66),
3580 }],
3581 ..AgentSettings::get_global(cx).clone()
3582 },
3583 cx,
3584 );
3585 });
3586
3587 let request = thread.update(cx, |thread, cx| {
3588 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3589 });
3590 assert_eq!(request.temperature, None);
3591 }
3592
    // Exercises the thread-summary state machine (Pending -> Generating ->
    // Ready), including which states reject manual summary updates.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3677
3678 #[gpui::test]
3679 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3680 init_test_settings(cx);
3681
3682 let project = create_test_project(cx, json!({})).await;
3683
3684 let (_, _thread_store, thread, _context_store, model) =
3685 setup_test_environment(cx, project.clone()).await;
3686
3687 test_summarize_error(&model, &thread, cx);
3688
3689 // Now we should be able to set a summary
3690 thread.update(cx, |thread, cx| {
3691 thread.set_summary("Brief Intro", cx);
3692 });
3693
3694 thread.read_with(cx, |thread, _| {
3695 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3696 assert_eq!(thread.summary().or_default(), "Brief Intro");
3697 });
3698 }
3699
    // After a failed summarization, further messages must not implicitly retry;
    // only an explicit `summarize` call restarts generation.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Put the thread summary into the `Error` state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary for the manual retry.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3750
    // Verifies `resolve_tool_name_conflicts`: names pass through unchanged
    // unless they collide, in which case context-server tools are prefixed
    // with their server id; resulting names are truncated to 64 characters.
    #[gpui::test]
    fn test_resolve_tool_name_conflicts() {
        use assistant_tool::{Tool, ToolSource};

        // Distinct names: everything passes through unchanged.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
            ],
            vec!["tool1", "tool2", "tool3"],
        );

        // Two context servers expose the same name: both get server-id prefixes.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // Native tool collides with context-server tools: the native name is
        // kept as-is; only the server-provided ones are prefixed.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // Test that tool with very long name is always truncated
        assert_resolve_tool_name_conflicts(
            vec![TestTool::new(
                "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
                ToolSource::Native,
            )],
            vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
        );

        // Test deduplication of tools with very long names, in this case the mcp server name should be truncated
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
                TestTool::new(
                    "tool-with-very-very-very-long-name",
                    ToolSource::ContextServer {
                        id: "mcp-with-very-very-very-long-name".into(),
                    },
                ),
            ],
            vec![
                "tool-with-very-very-very-long-name",
                "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
            ],
        );

        // Runs `resolve_tool_name_conflicts` over `tools` and asserts the
        // resulting names equal `expected`, in order.
        fn assert_resolve_tool_name_conflicts(
            tools: Vec<TestTool>,
            expected: Vec<impl Into<String>>,
        ) {
            let tools: Vec<Arc<dyn Tool>> = tools
                .into_iter()
                .map(|t| Arc::new(t) as Arc<dyn Tool>)
                .collect();
            let tools = resolve_tool_name_conflicts(&tools);
            assert_eq!(tools.len(), expected.len());
            for (i, expected_name) in expected.into_iter().enumerate() {
                let expected_name = expected_name.into();
                let actual_name = &tools[i].0;
                assert_eq!(
                    actual_name, &expected_name,
                    "Expected '{}' got '{}' at index {}",
                    expected_name, actual_name, i
                );
            }
        }

        // Minimal `Tool` implementation whose only meaningful attributes are
        // its name and source; everything else is stubbed.
        struct TestTool {
            name: String,
            source: ToolSource,
        }

        impl TestTool {
            fn new(name: impl Into<String>, source: ToolSource) -> Self {
                Self {
                    name: name.into(),
                    source,
                }
            }
        }

        impl Tool for TestTool {
            fn name(&self) -> String {
                self.name.clone()
            }

            fn icon(&self) -> IconName {
                IconName::Ai
            }

            fn may_perform_edits(&self) -> bool {
                false
            }

            fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
                true
            }

            fn source(&self) -> ToolSource {
                self.source.clone()
            }

            fn description(&self) -> String {
                "Test tool".to_string()
            }

            fn ui_text(&self, _input: &serde_json::Value) -> String {
                "Test tool".to_string()
            }

            // Never invoked by this test — name resolution happens without
            // running any tool — so the body just returns an error task.
            fn run(
                self: Arc<Self>,
                _input: serde_json::Value,
                _request: Arc<LanguageModelRequest>,
                _project: Entity<Project>,
                _action_log: Entity<ActionLog>,
                _model: Arc<dyn LanguageModel>,
                _window: Option<AnyWindowHandle>,
                _cx: &mut App,
            ) -> assistant_tool::ToolResult {
                assistant_tool::ToolResult {
                    output: Task::ready(Err(anyhow::anyhow!("No content"))),
                    card: None,
                }
            }
        }
    }
3892
3893 fn test_summarize_error(
3894 model: &Arc<dyn LanguageModel>,
3895 thread: &Entity<Thread>,
3896 cx: &mut TestAppContext,
3897 ) {
3898 thread.update(cx, |thread, cx| {
3899 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3900 thread.send_to_model(
3901 model.clone(),
3902 CompletionIntent::ThreadSummarization,
3903 None,
3904 cx,
3905 );
3906 });
3907
3908 let fake_model = model.as_fake();
3909 simulate_successful_response(&fake_model, cx);
3910
3911 thread.read_with(cx, |thread, _| {
3912 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3913 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3914 });
3915
3916 // Simulate summary request ending
3917 cx.run_until_parked();
3918 fake_model.end_last_completion_stream();
3919 cx.run_until_parked();
3920
3921 // State is set to Error and default message
3922 thread.read_with(cx, |thread, _| {
3923 assert!(matches!(thread.summary(), ThreadSummary::Error));
3924 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3925 });
3926 }
3927
3928 fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
3929 cx.run_until_parked();
3930 fake_model.stream_last_completion_response("Assistant response");
3931 fake_model.end_last_completion_stream();
3932 cx.run_until_parked();
3933 }
3934
    // Installs the global settings and registries every test in this module
    // depends on (project, agent, language-model, theme, editor, tools, …).
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            // NOTE(review): registration order kept as-is; some init calls
            // presumably rely on the settings store being installed first.
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3951
3952 // Helper to create a test project with test files
3953 async fn create_test_project(
3954 cx: &mut TestAppContext,
3955 files: serde_json::Value,
3956 ) -> Entity<Project> {
3957 let fs = FakeFs::new(cx.executor());
3958 fs.insert_tree(path!("/test"), files).await;
3959 Project::test(fs, [path!("/test").as_ref()], cx).await
3960 }
3961
3962 async fn setup_test_environment(
3963 cx: &mut TestAppContext,
3964 project: Entity<Project>,
3965 ) -> (
3966 Entity<Workspace>,
3967 Entity<ThreadStore>,
3968 Entity<Thread>,
3969 Entity<ContextStore>,
3970 Arc<dyn LanguageModel>,
3971 ) {
3972 let (workspace, cx) =
3973 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3974
3975 let thread_store = cx
3976 .update(|_, cx| {
3977 ThreadStore::load(
3978 project.clone(),
3979 cx.new(|_| ToolWorkingSet::default()),
3980 None,
3981 Arc::new(PromptBuilder::new(None).unwrap()),
3982 cx,
3983 )
3984 })
3985 .await
3986 .unwrap();
3987
3988 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3989 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3990
3991 let provider = Arc::new(FakeLanguageModelProvider);
3992 let model = provider.test_model();
3993 let model: Arc<dyn LanguageModel> = Arc::new(model);
3994
3995 cx.update(|_, cx| {
3996 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3997 registry.set_default_model(
3998 Some(ConfiguredModel {
3999 provider: provider.clone(),
4000 model: model.clone(),
4001 }),
4002 cx,
4003 );
4004 registry.set_thread_summary_model(
4005 Some(ConfiguredModel {
4006 provider,
4007 model: model.clone(),
4008 }),
4009 cx,
4010 );
4011 })
4012 });
4013
4014 (workspace, thread_store, thread, context_store, model)
4015 }
4016
4017 async fn add_file_to_context(
4018 project: &Entity<Project>,
4019 context_store: &Entity<ContextStore>,
4020 path: &str,
4021 cx: &mut TestAppContext,
4022 ) -> Result<Entity<language::Buffer>> {
4023 let buffer_path = project
4024 .read_with(cx, |project, cx| project.find_project_path(path, cx))
4025 .unwrap();
4026
4027 let buffer = project
4028 .update(cx, |project, cx| {
4029 project.open_buffer(buffer_path.clone(), cx)
4030 })
4031 .await
4032 .unwrap();
4033
4034 context_store.update(cx, |context_store, cx| {
4035 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
4036 });
4037
4038 Ok(buffer)
4039 }
4040}