1use crate::{
2 agent_profile::AgentProfile,
3 context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
4 thread_store::{
5 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
6 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
7 ThreadStore,
8 },
9 tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
10};
11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
12use anyhow::{Result, anyhow};
13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
14use chrono::{DateTime, Utc};
15use client::{ModelRequestUsage, RequestUsage};
16use collections::{HashMap, HashSet};
17use feature_flags::{self, FeatureFlagAppExt};
18use futures::{FutureExt, StreamExt as _, future::Shared};
19use git::repository::DiffType;
20use gpui::{
21 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
22 WeakEntity, Window,
23};
24use language_model::{
25 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
27 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
28 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
29 ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
30 TokenUsage,
31};
32use postage::stream::Stream as _;
33use project::{
34 Project,
35 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
36};
37use prompt_store::{ModelContext, PromptBuilder};
38use proto::Plan;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use settings::Settings;
42use std::{io::Write, ops::Range, sync::Arc, time::Instant};
43use thiserror::Error;
44use util::{ResultExt as _, post_inc};
45use uuid::Uuid;
46use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
47
/// Globally unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    /// Creates a fresh, random thread ID (UUID v4).
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    /// Wraps an existing ID string (e.g. one read back from storage).
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
70
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Creates a fresh, random prompt ID (UUID v4).
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
88
/// Identifier of a message within a thread, assigned from a monotonically
/// increasing per-thread counter (see [`Thread::insert_message`]).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current ID and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }

    /// The raw numeric value of this ID.
    pub fn as_usize(&self) -> usize {
        self.0
    }
}
101
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message text.
    pub range: Range<usize>,
    /// Icon shown in the collapsed crease.
    pub icon_path: SharedString,
    /// Label shown in the collapsed crease.
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
111
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (plain text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context (files, symbols, etc.) attached to this message.
    pub loaded_context: LoadedContext,
    /// Creases used to re-render collapsed context in editors.
    pub creases: Vec<MessageCrease>,
    /// Hidden messages (e.g. auto-inserted "continue" prompts) are not shown in the UI.
    pub is_hidden: bool,
}
122
123impl Message {
124 /// Returns whether the message contains any meaningful text that should be displayed
125 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
126 pub fn should_display_content(&self) -> bool {
127 self.segments.iter().all(|segment| segment.should_display())
128 }
129
130 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
131 if let Some(MessageSegment::Thinking {
132 text: segment,
133 signature: current_signature,
134 }) = self.segments.last_mut()
135 {
136 if let Some(signature) = signature {
137 *current_signature = Some(signature);
138 }
139 segment.push_str(text);
140 } else {
141 self.segments.push(MessageSegment::Thinking {
142 text: text.to_string(),
143 signature,
144 });
145 }
146 }
147
148 pub fn push_text(&mut self, text: &str) {
149 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
150 segment.push_str(text);
151 } else {
152 self.segments.push(MessageSegment::Text(text.to_string()));
153 }
154 }
155
156 pub fn to_string(&self) -> String {
157 let mut result = String::new();
158
159 if !self.loaded_context.text.is_empty() {
160 result.push_str(&self.loaded_context.text);
161 }
162
163 for segment in &self.segments {
164 match segment {
165 MessageSegment::Text(text) => result.push_str(text),
166 MessageSegment::Thinking { text, .. } => {
167 result.push_str("<think>\n");
168 result.push_str(text);
169 result.push_str("\n</think>");
170 }
171 MessageSegment::RedactedThinking(_) => {}
172 }
173 }
174
175 result
176 }
177}
178
/// One contiguous piece of a message's content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain assistant/user text.
    Text(String),
    /// Model "thinking" output, with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted thinking; opaque bytes, never displayed.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment carries meaningful, user-visible content.
    ///
    /// Fix: the text/thinking arms previously returned `text.is_empty()`,
    /// i.e. the condition was inverted — empty segments counted as
    /// displayable and non-empty ones did not, contradicting the documented
    /// intent of [`Message::should_display_content`].
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking is opaque bytes and never displayable.
            Self::RedactedThinking(_) => false,
        }
    }
}
198
/// Snapshot of the project's worktrees and unsaved-buffer state at a point in
/// time, stored alongside a thread.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree portion of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when no git state was captured for this worktree.
    pub git_state: Option<GitState>,
}

/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// Associates a git checkpoint with the message that created it, allowing the
/// project to be rolled back to the state before that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback (thumbs up / thumbs down) on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
231
/// State of the most recent checkpoint-restore operation.
pub enum LastRestoreCheckpoint {
    /// Restore is still in flight.
    Pending {
        message_id: MessageId,
    },
    /// Restore failed; `error` holds the failure description.
    Error {
        message_id: MessageId,
        error: String,
    },
}
241
242impl LastRestoreCheckpoint {
243 pub fn message_id(&self) -> MessageId {
244 match self {
245 LastRestoreCheckpoint::Pending { message_id } => *message_id,
246 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
247 }
248 }
249}
250
/// Lifecycle of the long-form ("detailed") thread summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    /// No detailed summary has been generated yet.
    #[default]
    NotGenerated,
    /// Generation is in progress, covering messages up to `message_id`.
    Generating {
        message_id: MessageId,
    },
    /// Generation finished; `text` covers messages up to `message_id`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
263
264impl DetailedSummaryState {
265 fn text(&self) -> Option<SharedString> {
266 if let Self::Generated { text, .. } = self {
267 Some(text.clone())
268 } else {
269 None
270 }
271 }
272}
273
/// Running token totals for a thread, measured against the model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: u64,
    /// Context-window size; 0 when unknown (no model selected).
    pub max: u64,
}

impl TotalTokenUsage {
    /// Classifies current usage as normal, nearing the limit, or exceeded.
    ///
    /// In debug builds the warning threshold can be overridden via the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable. Fix: an unset or
    /// unparsable value now falls back to the default instead of panicking
    /// (the previous code called `.parse().unwrap()` on the user-provided
    /// string).
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Coarse classification of context-window consumption.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
318
/// Where a pending completion currently sits with respect to the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent.
    Sending,
    /// The request is queued behind `position` other requests.
    Queued { position: usize },
    /// The completion has started streaming.
    Started,
}
325
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    /// Background task generating the short summary, if any.
    pending_summary: Task<Option<()>>,
    /// Background task generating the detailed summary, if any.
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    messages: Vec<Message>,
    /// Next message ID to hand out; monotonically increasing.
    next_message_id: MessageId,
    /// ID of the most recent user-initiated prompt.
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the message that created them.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// State of the most recent checkpoint restore, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken for the in-flight user message; finalized or dropped
    /// when generation completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured when the thread was created (shared future).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Token usage recorded per request.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// When the last streamed chunk arrived; drives staleness detection.
    last_received_chunk_at: Option<Instant>,
    /// Optional hook invoked with each request and the events received for it.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
366
/// State of a thread's short, human-readable summary (its title).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    /// No summary generation has started yet.
    Pending,
    /// A summary is currently being generated.
    Generating,
    /// Summary generation completed with this text.
    Ready(SharedString),
    /// Summary generation failed.
    Error,
}
374
375impl ThreadSummary {
376 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
377
378 pub fn or_default(&self) -> SharedString {
379 self.unwrap_or(Self::DEFAULT)
380 }
381
382 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
383 self.ready().unwrap_or_else(|| message.into())
384 }
385
386 pub fn ready(&self) -> Option<SharedString> {
387 match self {
388 ThreadSummary::Ready(summary) => Some(summary.clone()),
389 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
390 }
391 }
392}
393
/// Details recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
401
402impl Thread {
    /// Creates a new, empty thread for `project`.
    ///
    /// The model defaults to the global default model and the profile to the
    /// default profile from settings. An initial project snapshot is captured
    /// asynchronously.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Capture the snapshot in the background; `.shared()` lets
            // multiple consumers await the same result.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
457
    /// Reconstructs a [`Thread`] from its serialized representation.
    ///
    /// The configured model is resolved from the serialized provider/model
    /// pair when possible, falling back to the global default model.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest serialized message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // full context handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
579
    /// Installs a callback invoked with each completion request and the
    /// stream of events received for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
587
    /// The unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// The agent profile currently in use.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    /// Switches to the profile with `id`, emitting
    /// [`ThreadEvent::ProfileChanged`] only when it actually changes.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }

    /// Whether the thread contains no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified right now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The shared project context used when building the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, lazily falling back to the global
    /// default model if none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the model used for completions.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
638
    /// Current state of the thread's short summary (title).
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }

    /// Sets the summary text. Ignored while a summary is pending or being
    /// generated; empty input is replaced with the default title. Emits
    /// [`ThreadEvent::SummaryChanged`] only when the text actually changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            // An errored summary may be overwritten; compare against the default.
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    /// The completion mode used for requests.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the completion mode used for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
669
    /// Looks up a message by ID.
    ///
    /// Uses binary search, relying on `messages` being ordered by ascending
    /// ID (IDs come from a monotonically increasing counter and messages are
    /// appended in order).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
682
    /// Whether a completion is in flight or any tool use is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale, i.e.
    /// no chunk has arrived within the threshold.
    ///
    /// Returns `None` when no chunk has been received yet.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before the stream counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived (for staleness tracking).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
705
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by ID.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that require user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
731
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Progress and failure are tracked via `last_restore_checkpoint`;
    /// [`ThreadEvent::CheckpointChanged`] is emitted both when the restore
    /// starts and when it settles.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error around so the UI can surface it.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
767
    /// Finalizes the pending checkpoint once generation has fully finished;
    /// no-op while still generating or when there is no pending checkpoint.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        self.finalize_checkpoint(pending_checkpoint, cx);
    }

    /// Compares the pending checkpoint against the current git state and
    /// records it only when the project actually changed (or when the
    /// comparison itself failed, to err on the side of keeping it).
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Could not take a final checkpoint — keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
814
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
825
    /// Removes the message with `message_id` and every later message,
    /// dropping any checkpoints associated with the removed messages.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// Iterates over the context attached to the message with the given ID.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
847
848 pub fn is_turn_end(&self, ix: usize) -> bool {
849 if self.messages.is_empty() {
850 return false;
851 }
852
853 if !self.is_generating() && ix == self.messages.len() - 1 {
854 return true;
855 }
856
857 let Some(message) = self.messages.get(ix) else {
858 return false;
859 };
860
861 if message.role != Role::Assistant {
862 return false;
863 }
864
865 self.messages
866 .get(ix + 1)
867 .and_then(|message| {
868 self.message(message.id)
869 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
870 })
871 .unwrap_or(false)
872 }
873
    /// Whether the provider-imposed tool-use limit has been hit.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }
896
    /// Tool uses issued by the message with the given ID.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results attached to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if one has been recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Text output of a tool use; `None` for image results or missing results.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// The UI card associated with a tool use, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
925
    /// Return tools that are both enabled and supported by the model
    ///
    /// Returns an empty list when the model does not support tool calls.
    /// Name conflicts among enabled tools are resolved before building the
    /// request payload.
    pub fn available_tools(
        &self,
        cx: &App,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<LanguageModelRequestTool> {
        if model.supports_tools() {
            resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
                .into_iter()
                .filter_map(|(name, tool)| {
                    // Skip tools that cannot be supported
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name,
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect()
        } else {
            Vec::default()
        }
    }
949
    /// Appends a user message, marking any context-referenced buffers as read
    /// in the action log and stashing `git_checkpoint` as the pending
    /// checkpoint for this turn.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Inserts a hidden "continue" user message used to resume generation
    /// without user-visible input; clears any pending checkpoint.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }
1000
    /// Appends an assistant message composed of the given segments.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }

    /// Appends a message with the next sequential ID, bumps `updated_at`,
    /// and emits [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1038
1039 pub fn edit_message(
1040 &mut self,
1041 id: MessageId,
1042 new_role: Role,
1043 new_segments: Vec<MessageSegment>,
1044 creases: Vec<MessageCrease>,
1045 loaded_context: Option<LoadedContext>,
1046 checkpoint: Option<GitStoreCheckpoint>,
1047 cx: &mut Context<Self>,
1048 ) -> bool {
1049 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1050 return false;
1051 };
1052 message.role = new_role;
1053 message.segments = new_segments;
1054 message.creases = creases;
1055 if let Some(context) = loaded_context {
1056 message.loaded_context = context;
1057 }
1058 if let Some(git_checkpoint) = checkpoint {
1059 self.checkpoints_by_message.insert(
1060 id,
1061 ThreadCheckpoint {
1062 message_id: id,
1063 git_checkpoint,
1064 },
1065 );
1066 }
1067 self.touch_updated_at();
1068 cx.emit(ThreadEvent::MessageEdited(id));
1069 true
1070 }
1071
1072 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1073 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1074 return false;
1075 };
1076 self.messages.remove(index);
1077 self.touch_updated_at();
1078 cx.emit(ThreadEvent::MessageDeleted(id));
1079 true
1080 }
1081
1082 /// Returns the representation of this [`Thread`] in a textual form.
1083 ///
1084 /// This is the representation we use when attaching a thread as context to another thread.
1085 pub fn text(&self) -> String {
1086 let mut text = String::new();
1087
1088 for message in &self.messages {
1089 text.push_str(match message.role {
1090 language_model::Role::User => "User:",
1091 language_model::Role::Assistant => "Agent:",
1092 language_model::Role::System => "System:",
1093 });
1094 text.push('\n');
1095
1096 for segment in &message.segments {
1097 match segment {
1098 MessageSegment::Text(content) => text.push_str(content),
1099 MessageSegment::Thinking { text: content, .. } => {
1100 text.push_str(&format!("<think>{}</think>", content))
1101 }
1102 MessageSegment::RedactedThinking(_) => {}
1103 }
1104 }
1105 text.push('\n');
1106 }
1107
1108 text
1109 }
1110
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot future to resolve, then reads
    /// the thread state and maps every message (segments, tool uses, tool
    /// results, context, creases) into its serialized counterpart.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Resolve the shared snapshot future before reading thread state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses and their results are stored per message.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1196
    /// The number of turns this thread may still send to the model
    /// (consumed by [`Thread::send_to_model`]).
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1200
    /// Sets the number of turns this thread may still send to the model.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1204
1205 pub fn send_to_model(
1206 &mut self,
1207 model: Arc<dyn LanguageModel>,
1208 intent: CompletionIntent,
1209 window: Option<AnyWindowHandle>,
1210 cx: &mut Context<Self>,
1211 ) {
1212 if self.remaining_turns == 0 {
1213 return;
1214 }
1215
1216 self.remaining_turns -= 1;
1217
1218 let request = self.to_completion_request(model.clone(), intent, cx);
1219
1220 self.stream_completion(request, model, window, cx);
1221 }
1222
1223 pub fn used_tools_since_last_user_message(&self) -> bool {
1224 for message in self.messages.iter().rev() {
1225 if self.tool_use.message_has_tool_results(message.id) {
1226 return true;
1227 } else if message.role == Role::User {
1228 return false;
1229 }
1230 }
1231
1232 false
1233 }
1234
    /// Builds the [`LanguageModelRequest`] for this thread: system prompt,
    /// every message (loaded context, segments, tool uses/results), the
    /// available tools, and the completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt needs the shared project context; surface an error
        // to the UI when it is missing or prompt generation fails.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message that is safe to mark as a prompt
        // cache anchor (i.e. has no still-pending tool uses).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            // Empty text/thinking segments are dropped; redacted thinking is
            // always forwarded.
            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses are echoed back on this message, with their
            // results collected into a follow-up user message. Pending tool
            // uses are skipped and disable caching for this message.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1392
1393 fn to_summarize_request(
1394 &self,
1395 model: &Arc<dyn LanguageModel>,
1396 intent: CompletionIntent,
1397 added_user_message: String,
1398 cx: &App,
1399 ) -> LanguageModelRequest {
1400 let mut request = LanguageModelRequest {
1401 thread_id: None,
1402 prompt_id: None,
1403 intent: Some(intent),
1404 mode: None,
1405 messages: vec![],
1406 tools: Vec::new(),
1407 tool_choice: None,
1408 stop: Vec::new(),
1409 temperature: AgentSettings::temperature_for_model(model, cx),
1410 };
1411
1412 for message in &self.messages {
1413 let mut request_message = LanguageModelRequestMessage {
1414 role: message.role,
1415 content: Vec::new(),
1416 cache: false,
1417 };
1418
1419 for segment in &message.segments {
1420 match segment {
1421 MessageSegment::Text(text) => request_message
1422 .content
1423 .push(MessageContent::Text(text.clone())),
1424 MessageSegment::Thinking { .. } => {}
1425 MessageSegment::RedactedThinking(_) => {}
1426 }
1427 }
1428
1429 if request_message.content.is_empty() {
1430 continue;
1431 }
1432
1433 request.messages.push(request_message);
1434 }
1435
1436 request.messages.push(LanguageModelRequestMessage {
1437 role: Role::User,
1438 content: vec![MessageContent::Text(added_user_message)],
1439 cache: false,
1440 });
1441
1442 request
1443 }
1444
    /// Streams a completion for `request` from `model` on a background task,
    /// applying events to the thread as they arrive, and registers the work in
    /// `self.pending_completions`. On stream end it finalizes the pending
    /// checkpoint, dispatches on the stop reason (tool use / end turn /
    /// refusal), surfaces errors, and reports telemetry.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record the request/response pair when a callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Baseline usage, so the telemetry delta below covers only this request.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        // Translate stream errors into known error types so the
                        // outer handler can react to each specifically.
                        let event = match event {
                            Ok(event) => event,
                            Err(error) => {
                                match error {
                                    LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
                                        anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
                                    }
                                    LanguageModelCompletionError::Overloaded => {
                                        anyhow::bail!(LanguageModelKnownError::Overloaded);
                                    }
                                    LanguageModelCompletionError::ApiInternalServerError =>{
                                        anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
                                    }
                                    LanguageModelCompletionError::PromptTooLarge { tokens } => {
                                        let tokens = tokens.unwrap_or_else(|| {
                                            // We didn't get an exact token count from the API, so fall back on our estimate.
                                            thread.total_token_usage()
                                                .map(|usage| usage.total)
                                                .unwrap_or(0)
                                                // We know the context window was exceeded in practice, so if our estimate was
                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
                                                .max(model.max_token_count().saturating_add(1))
                                        });

                                        anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
                                    }
                                    LanguageModelCompletionError::ApiReadResponseError(io_error) => {
                                        anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
                                    }
                                    LanguageModelCompletionError::UnknownResponseFormat(error) => {
                                        anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
                                    }
                                    LanguageModelCompletionError::HttpResponseError { status, ref body } => {
                                        if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
                                            anyhow::bail!(known_error);
                                        } else {
                                            return Err(error.into());
                                        }
                                    }
                                    LanguageModelCompletionError::DeserializeResponse(error) => {
                                        anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
                                    }
                                    LanguageModelCompletionError::BadInputJson {
                                        id,
                                        tool_name,
                                        raw_input: invalid_input_json,
                                        json_parse_error,
                                    } => {
                                        // Recoverable: report the bad JSON back to the
                                        // model as a tool error and keep streaming.
                                        thread.receive_invalid_tool_json(
                                            id,
                                            tool_name,
                                            invalid_input_json,
                                            json_parse_error,
                                            window,
                                            cx,
                                        );
                                        return Ok(());
                                    }
                                    // These are all errors we can't automatically attempt to recover from (e.g. by retrying)
                                    err @ LanguageModelCompletionError::BadRequestFormat |
                                    err @ LanguageModelCompletionError::AuthenticationError |
                                    err @ LanguageModelCompletionError::PermissionError |
                                    err @ LanguageModelCompletionError::ApiEndpointNotFound |
                                    err @ LanguageModelCompletionError::SerializeRequest(_) |
                                    err @ LanguageModelCompletionError::BuildRequestBody(_) |
                                    err @ LanguageModelCompletionError::HttpSend(_) => {
                                        anyhow::bail!(err);
                                    }
                                    LanguageModelCompletionError::Other(error) => {
                                        return Err(error);
                                    }
                                }
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so add
                                // only the delta since the previous update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Keep the partial input around only while the
                                // tool's input JSON is still streaming in.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            thread.update_model_request_usage(amount as u32, limit, cx);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so other tasks can make progress between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                            StopReason::Refusal => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });

                                // Remove the turn that was refused.
                                //
                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                {
                                    let mut messages_to_remove = Vec::new();

                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
                                        messages_to_remove.push(message.id);

                                        if message.role == Role::User {
                                            if ix == 0 {
                                                break;
                                            }

                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                if prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }
                                    }

                                    for message_id in messages_to_remove {
                                        thread.delete_message(message_id, cx);
                                    }
                                }

                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Language model refusal".into(),
                                    message: "Model refused to generate content for safety reasons.".into(),
                                }));
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Shows the full error chain in a UI error message.
                            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                    LanguageModelKnownError::RateLimitExceeded { .. } => {
                                        // In the future we will report the error to the user, wait retry_after, and then retry.
                                        emit_generic_error(error, cx);
                                    }
                                    LanguageModelKnownError::Overloaded => {
                                        // In the future we will wait and then retry, up to N times.
                                        emit_generic_error(error, cx);
                                    }
                                    LanguageModelKnownError::ApiInternalServerError => {
                                        // In the future we will retry the request, but only once.
                                        emit_generic_error(error, cx);
                                    }
                                    LanguageModelKnownError::ReadResponseError(_) |
                                    LanguageModelKnownError::DeserializeResponse(_) |
                                    LanguageModelKnownError::UnknownResponseFormat(_) => {
                                        // In the future we will attempt to re-roll response, but only once
                                        emit_generic_error(error, cx);
                                    }
                                }
                            } else {
                                emit_generic_error(error, cx);
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }

                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1896
1897 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1898 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1899 println!("No thread summary model");
1900 return;
1901 };
1902
1903 if !model.provider.is_authenticated(cx) {
1904 return;
1905 }
1906
1907 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
1908
1909 let request = self.to_summarize_request(
1910 &model.model,
1911 CompletionIntent::ThreadSummarization,
1912 added_user_message.into(),
1913 cx,
1914 );
1915
1916 self.summary = ThreadSummary::Generating;
1917
1918 self.pending_summary = cx.spawn(async move |this, cx| {
1919 let result = async {
1920 let mut messages = model.model.stream_completion(request, &cx).await?;
1921
1922 let mut new_summary = String::new();
1923 while let Some(event) = messages.next().await {
1924 let Ok(event) = event else {
1925 continue;
1926 };
1927 let text = match event {
1928 LanguageModelCompletionEvent::Text(text) => text,
1929 LanguageModelCompletionEvent::StatusUpdate(
1930 CompletionRequestStatus::UsageUpdated { amount, limit },
1931 ) => {
1932 this.update(cx, |thread, cx| {
1933 thread.update_model_request_usage(amount as u32, limit, cx);
1934 })?;
1935 continue;
1936 }
1937 _ => continue,
1938 };
1939
1940 let mut lines = text.lines();
1941 new_summary.extend(lines.next());
1942
1943 // Stop if the LLM generated multiple lines.
1944 if lines.next().is_some() {
1945 break;
1946 }
1947 }
1948
1949 anyhow::Ok(new_summary)
1950 }
1951 .await;
1952
1953 this.update(cx, |this, cx| {
1954 match result {
1955 Ok(new_summary) => {
1956 if new_summary.is_empty() {
1957 this.summary = ThreadSummary::Error;
1958 } else {
1959 this.summary = ThreadSummary::Ready(new_summary.into());
1960 }
1961 }
1962 Err(err) => {
1963 this.summary = ThreadSummary::Error;
1964 log::error!("Failed to generate thread summary: {}", err);
1965 }
1966 }
1967 cx.emit(ThreadEvent::SummaryGenerated);
1968 })
1969 .log_err()?;
1970
1971 Some(())
1972 });
1973 }
1974
    /// Starts generating a detailed thread summary for the current last
    /// message, unless one is already generated (or generating) for it.
    /// Progress is published through `detailed_summary_tx`; on success the
    /// thread is saved so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize for an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails to start, roll the state back to NotGenerated.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2064
2065 pub async fn wait_for_detailed_summary_or_text(
2066 this: &Entity<Self>,
2067 cx: &mut AsyncApp,
2068 ) -> Option<SharedString> {
2069 let mut detailed_summary_rx = this
2070 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2071 .ok()?;
2072 loop {
2073 match detailed_summary_rx.recv().await? {
2074 DetailedSummaryState::Generating { .. } => {}
2075 DetailedSummaryState::NotGenerated => {
2076 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2077 }
2078 DetailedSummaryState::Generated { text, .. } => return Some(text),
2079 }
2080 }
2081 }
2082
2083 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2084 self.detailed_summary_rx
2085 .borrow()
2086 .text()
2087 .unwrap_or_else(|| self.text().into())
2088 }
2089
    /// Whether a detailed summary is currently being generated for this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2096
2097 pub fn use_pending_tools(
2098 &mut self,
2099 window: Option<AnyWindowHandle>,
2100 cx: &mut Context<Self>,
2101 model: Arc<dyn LanguageModel>,
2102 ) -> Vec<PendingToolUse> {
2103 self.auto_capture_telemetry(cx);
2104 let request =
2105 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2106 let pending_tool_uses = self
2107 .tool_use
2108 .pending_tool_uses()
2109 .into_iter()
2110 .filter(|tool_use| tool_use.status.is_idle())
2111 .cloned()
2112 .collect::<Vec<_>>();
2113
2114 for tool_use in pending_tool_uses.iter() {
2115 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
2116 if tool.needs_confirmation(&tool_use.input, cx)
2117 && !AgentSettings::get_global(cx).always_allow_tool_actions
2118 {
2119 self.tool_use.confirm_tool_use(
2120 tool_use.id.clone(),
2121 tool_use.ui_text.clone(),
2122 tool_use.input.clone(),
2123 request.clone(),
2124 tool,
2125 );
2126 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2127 } else {
2128 self.run_tool(
2129 tool_use.id.clone(),
2130 tool_use.ui_text.clone(),
2131 tool_use.input.clone(),
2132 request.clone(),
2133 tool,
2134 model.clone(),
2135 window,
2136 cx,
2137 );
2138 }
2139 } else {
2140 self.handle_hallucinated_tool_use(
2141 tool_use.id.clone(),
2142 tool_use.name.clone(),
2143 window,
2144 cx,
2145 );
2146 }
2147 }
2148
2149 pending_tool_uses
2150 }
2151
2152 pub fn handle_hallucinated_tool_use(
2153 &mut self,
2154 tool_use_id: LanguageModelToolUseId,
2155 hallucinated_tool_name: Arc<str>,
2156 window: Option<AnyWindowHandle>,
2157 cx: &mut Context<Thread>,
2158 ) {
2159 let available_tools = self.profile.enabled_tools(cx);
2160
2161 let tool_list = available_tools
2162 .iter()
2163 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2164 .collect::<Vec<_>>()
2165 .join("\n");
2166
2167 let error_message = format!(
2168 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2169 hallucinated_tool_name, tool_list
2170 );
2171
2172 let pending_tool_use = self.tool_use.insert_tool_output(
2173 tool_use_id.clone(),
2174 hallucinated_tool_name,
2175 Err(anyhow!("Missing tool call: {error_message}")),
2176 self.configured_model.as_ref(),
2177 );
2178
2179 cx.emit(ThreadEvent::MissingToolUse {
2180 tool_use_id: tool_use_id.clone(),
2181 ui_text: error_message.into(),
2182 });
2183
2184 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2185 }
2186
2187 pub fn receive_invalid_tool_json(
2188 &mut self,
2189 tool_use_id: LanguageModelToolUseId,
2190 tool_name: Arc<str>,
2191 invalid_json: Arc<str>,
2192 error: String,
2193 window: Option<AnyWindowHandle>,
2194 cx: &mut Context<Thread>,
2195 ) {
2196 log::error!("The model returned invalid input JSON: {invalid_json}");
2197
2198 let pending_tool_use = self.tool_use.insert_tool_output(
2199 tool_use_id.clone(),
2200 tool_name,
2201 Err(anyhow!("Error parsing input JSON: {error}")),
2202 self.configured_model.as_ref(),
2203 );
2204 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2205 pending_tool_use.ui_text.clone()
2206 } else {
2207 log::error!(
2208 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2209 );
2210 format!("Unknown tool {}", tool_use_id).into()
2211 };
2212
2213 cx.emit(ThreadEvent::InvalidToolInput {
2214 tool_use_id: tool_use_id.clone(),
2215 ui_text,
2216 invalid_input_json: invalid_json,
2217 });
2218
2219 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2220 }
2221
2222 pub fn run_tool(
2223 &mut self,
2224 tool_use_id: LanguageModelToolUseId,
2225 ui_text: impl Into<SharedString>,
2226 input: serde_json::Value,
2227 request: Arc<LanguageModelRequest>,
2228 tool: Arc<dyn Tool>,
2229 model: Arc<dyn LanguageModel>,
2230 window: Option<AnyWindowHandle>,
2231 cx: &mut Context<Thread>,
2232 ) {
2233 let task =
2234 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2235 self.tool_use
2236 .run_pending_tool(tool_use_id, ui_text.into(), task);
2237 }
2238
    /// Starts running `tool` and returns a task that records the tool's output
    /// on this thread once its future resolves, then finishes the tool use.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // The tool starts here, synchronously; `tool_result.output` is the
        // future that resolves with its eventual output.
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool was running;
                // in that case there is nowhere to record the output (.ok()).
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2285
2286 fn tool_finished(
2287 &mut self,
2288 tool_use_id: LanguageModelToolUseId,
2289 pending_tool_use: Option<PendingToolUse>,
2290 canceled: bool,
2291 window: Option<AnyWindowHandle>,
2292 cx: &mut Context<Self>,
2293 ) {
2294 if self.all_tools_finished() {
2295 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2296 if !canceled {
2297 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2298 }
2299 self.auto_capture_telemetry(cx);
2300 }
2301 }
2302
2303 cx.emit(ThreadEvent::ToolFinished {
2304 tool_use_id,
2305 pending_tool_use,
2306 });
2307 }
2308
2309 /// Cancels the last pending completion, if there are any pending.
2310 ///
2311 /// Returns whether a completion was canceled.
2312 pub fn cancel_last_completion(
2313 &mut self,
2314 window: Option<AnyWindowHandle>,
2315 cx: &mut Context<Self>,
2316 ) -> bool {
2317 let mut canceled = self.pending_completions.pop().is_some();
2318
2319 for pending_tool_use in self.tool_use.cancel_pending() {
2320 canceled = true;
2321 self.tool_finished(
2322 pending_tool_use.id.clone(),
2323 Some(pending_tool_use),
2324 true,
2325 window,
2326 cx,
2327 );
2328 }
2329
2330 if canceled {
2331 cx.emit(ThreadEvent::CompletionCanceled);
2332
2333 // When canceled, we always want to insert the checkpoint.
2334 // (We skip over finalize_pending_checkpoint, because it
2335 // would conclude we didn't have anything to insert here.)
2336 if let Some(checkpoint) = self.pending_checkpoint.take() {
2337 self.insert_checkpoint(checkpoint, cx);
2338 }
2339 } else {
2340 self.finalize_pending_checkpoint(cx);
2341 }
2342
2343 canceled
2344 }
2345
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        // Pure notification; this thread holds no editing state of its own.
        cx.emit(ThreadEvent::CancelEditing);
    }
2353
    /// Returns the thread-level feedback rating, if one has been recorded
    /// (see `report_feedback`).
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2357
    /// Returns the feedback rating recorded for a specific message, if any
    /// (see `report_message_feedback`).
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2361
    /// Records the user's rating for a single message and reports it, together
    /// with a serialized copy of the thread and a project snapshot, via
    /// telemetry. Returns early if the same rating was already recorded.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Re-submitting an identical rating is a no-op.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Gather everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // losing the rating event entirely.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            // NOTE: `telemetry::event!` captures these locals by identifier;
            // the variable names double as the event's field names.
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2418
    /// Records feedback for the thread. When the thread has an assistant
    /// message, the rating is attached to the most recent one (delegating to
    /// `report_message_feedback`); otherwise a thread-level rating is stored
    /// and reported via telemetry.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach to: rate the thread as a whole.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Degrade serialization failure to a null payload.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                // NOTE: `telemetry::event!` captures these locals by identifier;
                // the variable names double as the event's field names.
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2464
2465 /// Create a snapshot of the current project state including git information and unsaved buffers.
2466 fn project_snapshot(
2467 project: Entity<Project>,
2468 cx: &mut Context<Self>,
2469 ) -> Task<Arc<ProjectSnapshot>> {
2470 let git_store = project.read(cx).git_store().clone();
2471 let worktree_snapshots: Vec<_> = project
2472 .read(cx)
2473 .visible_worktrees(cx)
2474 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2475 .collect();
2476
2477 cx.spawn(async move |_, cx| {
2478 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2479
2480 let mut unsaved_buffers = Vec::new();
2481 cx.update(|app_cx| {
2482 let buffer_store = project.read(app_cx).buffer_store();
2483 for buffer_handle in buffer_store.read(app_cx).buffers() {
2484 let buffer = buffer_handle.read(app_cx);
2485 if buffer.is_dirty() {
2486 if let Some(file) = buffer.file() {
2487 let path = file.path().to_string_lossy().to_string();
2488 unsaved_buffers.push(path);
2489 }
2490 }
2491 }
2492 })
2493 .ok();
2494
2495 Arc::new(ProjectSnapshot {
2496 worktree_snapshots,
2497 unsaved_buffer_paths: unsaved_buffers,
2498 timestamp: Utc::now(),
2499 })
2500 })
2501 }
2502
    /// Builds a task producing a `WorktreeSnapshot` (absolute path plus,
    /// when a local git repository covers the worktree, its remote URL, HEAD
    /// SHA, current branch, and head-to-worktree diff).
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose root contains this worktree's path.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        // Query git details on the repository's job queue.
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend to query.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the Option<Result<Task>> stack; any failure along the way
            // simply yields no git state.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2580
2581 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2582 let mut markdown = Vec::new();
2583
2584 let summary = self.summary().or_default();
2585 writeln!(markdown, "# {summary}\n")?;
2586
2587 for message in self.messages() {
2588 writeln!(
2589 markdown,
2590 "## {role}\n",
2591 role = match message.role {
2592 Role::User => "User",
2593 Role::Assistant => "Agent",
2594 Role::System => "System",
2595 }
2596 )?;
2597
2598 if !message.loaded_context.text.is_empty() {
2599 writeln!(markdown, "{}", message.loaded_context.text)?;
2600 }
2601
2602 if !message.loaded_context.images.is_empty() {
2603 writeln!(
2604 markdown,
2605 "\n{} images attached as context.\n",
2606 message.loaded_context.images.len()
2607 )?;
2608 }
2609
2610 for segment in &message.segments {
2611 match segment {
2612 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2613 MessageSegment::Thinking { text, .. } => {
2614 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2615 }
2616 MessageSegment::RedactedThinking(_) => {}
2617 }
2618 }
2619
2620 for tool_use in self.tool_uses_for_message(message.id, cx) {
2621 writeln!(
2622 markdown,
2623 "**Use Tool: {} ({})**",
2624 tool_use.name, tool_use.id
2625 )?;
2626 writeln!(markdown, "```json")?;
2627 writeln!(
2628 markdown,
2629 "{}",
2630 serde_json::to_string_pretty(&tool_use.input)?
2631 )?;
2632 writeln!(markdown, "```")?;
2633 }
2634
2635 for tool_result in self.tool_results_for_message(message.id) {
2636 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2637 if tool_result.is_error {
2638 write!(markdown, " (Error)")?;
2639 }
2640
2641 writeln!(markdown, "**\n")?;
2642 match &tool_result.content {
2643 LanguageModelToolResultContent::Text(text) => {
2644 writeln!(markdown, "{text}")?;
2645 }
2646 LanguageModelToolResultContent::Image(image) => {
2647 writeln!(markdown, "", image.source)?;
2648 }
2649 }
2650
2651 if let Some(output) = tool_result.output.as_ref() {
2652 writeln!(
2653 markdown,
2654 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2655 serde_json::to_string_pretty(output)?
2656 )?;
2657 }
2658 }
2659 }
2660
2661 Ok(String::from_utf8_lossy(&markdown).to_string())
2662 }
2663
    /// Accepts (keeps) agent edits within `buffer_range` of the given buffer,
    /// delegating to the thread's `ActionLog`.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2674
    /// Accepts (keeps) all outstanding agent edits, delegating to the thread's
    /// `ActionLog`.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2679
    /// Rejects (reverts) agent edits within the given ranges of a buffer,
    /// delegating to the thread's `ActionLog`. The returned task resolves when
    /// the revert completes.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2690
    /// The log of agent-made edits associated with this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2694
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2698
    /// Automatically captures a serialized copy of the thread via telemetry,
    /// gated behind the `ThreadAutoCapture` feature flag and rate-limited to
    /// at most once every 10 seconds.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
            return;
        }

        // Rate limit: skip if we captured less than 10 seconds ago.
        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        // Gather everything that needs `cx` before moving to the background.
        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client();
        let serialize_task = self.serialize(cx);

        // Serialization or JSON conversion failures silently drop the capture.
        cx.background_executor()
            .spawn(async move {
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events().await;
                    }
                }
            })
            .detach();
    }
2742
    /// Total token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2746
2747 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2748 let Some(model) = self.configured_model.as_ref() else {
2749 return TotalTokenUsage::default();
2750 };
2751
2752 let max = model.model.max_token_count();
2753
2754 let index = self
2755 .messages
2756 .iter()
2757 .position(|msg| msg.id == message_id)
2758 .unwrap_or(0);
2759
2760 if index == 0 {
2761 return TotalTokenUsage { total: 0, max };
2762 }
2763
2764 let token_usage = &self
2765 .request_token_usage
2766 .get(index - 1)
2767 .cloned()
2768 .unwrap_or_default();
2769
2770 TotalTokenUsage {
2771 total: token_usage.total_tokens(),
2772 max,
2773 }
2774 }
2775
2776 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2777 let model = self.configured_model.as_ref()?;
2778
2779 let max = model.model.max_token_count();
2780
2781 if let Some(exceeded_error) = &self.exceeded_window_error {
2782 if model.model.id() == exceeded_error.model_id {
2783 return Some(TotalTokenUsage {
2784 total: exceeded_error.token_count,
2785 max,
2786 });
2787 }
2788 }
2789
2790 let total = self
2791 .token_usage_at_last_message()
2792 .unwrap_or_default()
2793 .total_tokens();
2794
2795 Some(TotalTokenUsage { total, max })
2796 }
2797
    /// Returns the token usage recorded for the last message, falling back to
    /// the most recent recorded entry when the usage vector is shorter than
    /// the message list.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2804
2805 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2806 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2807 self.request_token_usage
2808 .resize(self.messages.len(), placeholder);
2809
2810 if let Some(last) = self.request_token_usage.last_mut() {
2811 *last = token_usage;
2812 }
2813 }
2814
    /// Propagates the latest model-request usage (amount consumed and plan
    /// limit) to the project's user store.
    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
        self.project.update(cx, |project, cx| {
            project.user_store().update(cx, |user_store, cx| {
                user_store.update_model_request_usage(
                    ModelRequestUsage(RequestUsage {
                        // RequestUsage stores a signed amount; the cast is safe
                        // for realistic request counts.
                        amount: amount as i32,
                        limit,
                    }),
                    cx,
                )
            })
        });
    }
2828
2829 pub fn deny_tool_use(
2830 &mut self,
2831 tool_use_id: LanguageModelToolUseId,
2832 tool_name: Arc<str>,
2833 window: Option<AnyWindowHandle>,
2834 cx: &mut Context<Self>,
2835 ) {
2836 let err = Err(anyhow::anyhow!(
2837 "Permission to run tool action denied by user"
2838 ));
2839
2840 self.tool_use.insert_tool_output(
2841 tool_use_id.clone(),
2842 tool_name,
2843 err,
2844 self.configured_model.as_ref(),
2845 );
2846 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2847 }
2848}
2849
/// User-facing errors surfaced while running a thread (see
/// `ThreadEvent::ShowError`).
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider refused the request pending payment.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has exhausted its model-request allowance.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a display header and detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2862
/// Events emitted by a `Thread` as completions stream in, tools run, and
/// messages change.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// A chunk of assistant response text arrived for the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant "thinking" text arrived for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model requested a tool that doesn't exist or isn't enabled
    /// (see `handle_hallucinated_tool_use`).
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model supplied tool input that failed to parse as JSON
    /// (see `receive_invalid_tool_json`).
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running, successfully or not (see `tool_finished`).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool requires user confirmation before it may run.
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    /// Listeners (e.g. ActiveThread) should cancel in-progress editing.
    CancelEditing,
    /// A pending completion (and/or its pending tools) was canceled.
    CompletionCanceled,
    ProfileChanged,
}
2907
// Lets `Thread` broadcast `ThreadEvent`s to subscribers via `cx.emit`.
impl EventEmitter<ThreadEvent> for Thread {}
2909
/// A completion request currently in flight for a thread.
struct PendingCompletion {
    // Identifier for this completion; used to tell completions apart.
    id: usize,
    queue_state: QueueState,
    // Owns the running request; dropped (and thereby abandoned) when this
    // entry is popped — see `cancel_last_completion`.
    _task: Task<()>,
}
2915
2916/// Resolves tool name conflicts by ensuring all tool names are unique.
2917///
2918/// When multiple tools have the same name, this function applies the following rules:
2919/// 1. Native tools always keep their original name
2920/// 2. Context server tools get prefixed with their server ID and an underscore
2921/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
2922/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
2923///
2924/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
2925fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
2926 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
2927 let mut tool_name = tool.name();
2928 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
2929 tool_name
2930 }
2931
2932 const MAX_TOOL_NAME_LENGTH: usize = 64;
2933
2934 let mut duplicated_tool_names = HashSet::default();
2935 let mut seen_tool_names = HashSet::default();
2936 for tool in tools {
2937 let tool_name = resolve_tool_name(tool);
2938 if seen_tool_names.contains(&tool_name) {
2939 debug_assert!(
2940 tool.source() != assistant_tool::ToolSource::Native,
2941 "There are two built-in tools with the same name: {}",
2942 tool_name
2943 );
2944 duplicated_tool_names.insert(tool_name);
2945 } else {
2946 seen_tool_names.insert(tool_name);
2947 }
2948 }
2949
2950 if duplicated_tool_names.is_empty() {
2951 return tools
2952 .into_iter()
2953 .map(|tool| (resolve_tool_name(tool), tool.clone()))
2954 .collect();
2955 }
2956
2957 tools
2958 .into_iter()
2959 .filter_map(|tool| {
2960 let mut tool_name = resolve_tool_name(tool);
2961 if !duplicated_tool_names.contains(&tool_name) {
2962 return Some((tool_name, tool.clone()));
2963 }
2964 match tool.source() {
2965 assistant_tool::ToolSource::Native => {
2966 // Built-in tools always keep their original name
2967 Some((tool_name, tool.clone()))
2968 }
2969 assistant_tool::ToolSource::ContextServer { id } => {
2970 // Context server tools are prefixed with the context server ID, and truncated if necessary
2971 tool_name.insert(0, '_');
2972 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
2973 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
2974 let mut id = id.to_string();
2975 id.truncate(len);
2976 tool_name.insert_str(0, &id);
2977 } else {
2978 tool_name.insert_str(0, &id);
2979 }
2980
2981 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
2982
2983 if seen_tool_names.contains(&tool_name) {
2984 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
2985 None
2986 } else {
2987 Some((tool_name, tool.clone()))
2988 }
2989 }
2990 }
2991 })
2992 .collect()
2993}
2994
2995#[cfg(test)]
2996mod tests {
2997 use super::*;
2998 use crate::{
2999 context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3000 };
3001 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3002 use assistant_tool::ToolRegistry;
3003 use gpui::TestAppContext;
3004 use icons::IconName;
3005 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3006 use project::{FakeFs, Project};
3007 use prompt_store::PromptBuilder;
3008 use serde_json::json;
3009 use settings::{Settings, SettingsStore};
3010 use std::sync::Arc;
3011 use theme::ThemeSettings;
3012 use util::path;
3013 use workspace::Workspace;
3014
    /// Verifies that a file attached as context is rendered into the message's
    /// `loaded_context` text and prepended to the user text in the completion
    /// request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        // Load the single attached context entry.
        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // Expected rendering: the attached file inside <context>/<files> tags.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
 println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Request is [system prompt, user message]; the user message carries
        // the context text immediately followed by the typed text.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3091
    /// Verifies that each new user message only includes context entries that
    /// weren't already attached to an earlier message, and that
    /// `new_context_for_thread` respects an "up to message" cutoff and context
    /// removal.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, contexts attached at/after message 2
        // (file2, file3) plus the newly added file4 count as "new" again.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3247
    /// Verifies that messages inserted without any attached context have empty
    /// `loaded_context` text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Request is [system prompt, user message].
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3325
3326 #[gpui::test]
3327 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3328 init_test_settings(cx);
3329
3330 let project = create_test_project(
3331 cx,
3332 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3333 )
3334 .await;
3335
3336 let (_workspace, thread_store, thread, _context_store, _model) =
3337 setup_test_environment(cx, project.clone()).await;
3338
3339 // Check that we are starting with the default profile
3340 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3341 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3342 assert_eq!(
3343 profile,
3344 AgentProfile::new(AgentProfileId::default(), tool_set)
3345 );
3346 }
3347
    // Round-trips a thread through serialize/deserialize and checks that the
    // default profile survives the trip.
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Rebuild a thread from the serialized form, reusing the live thread's
        // id, project, tools, prompt builder, and project context.
        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        // The deserialized thread carries the same default profile.
        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3390
3391 #[gpui::test]
3392 async fn test_temperature_setting(cx: &mut TestAppContext) {
3393 init_test_settings(cx);
3394
3395 let project = create_test_project(
3396 cx,
3397 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3398 )
3399 .await;
3400
3401 let (_workspace, _thread_store, thread, _context_store, model) =
3402 setup_test_environment(cx, project.clone()).await;
3403
3404 // Both model and provider
3405 cx.update(|cx| {
3406 AgentSettings::override_global(
3407 AgentSettings {
3408 model_parameters: vec![LanguageModelParameters {
3409 provider: Some(model.provider_id().0.to_string().into()),
3410 model: Some(model.id().0.clone()),
3411 temperature: Some(0.66),
3412 }],
3413 ..AgentSettings::get_global(cx).clone()
3414 },
3415 cx,
3416 );
3417 });
3418
3419 let request = thread.update(cx, |thread, cx| {
3420 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3421 });
3422 assert_eq!(request.temperature, Some(0.66));
3423
3424 // Only model
3425 cx.update(|cx| {
3426 AgentSettings::override_global(
3427 AgentSettings {
3428 model_parameters: vec![LanguageModelParameters {
3429 provider: None,
3430 model: Some(model.id().0.clone()),
3431 temperature: Some(0.66),
3432 }],
3433 ..AgentSettings::get_global(cx).clone()
3434 },
3435 cx,
3436 );
3437 });
3438
3439 let request = thread.update(cx, |thread, cx| {
3440 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3441 });
3442 assert_eq!(request.temperature, Some(0.66));
3443
3444 // Only provider
3445 cx.update(|cx| {
3446 AgentSettings::override_global(
3447 AgentSettings {
3448 model_parameters: vec![LanguageModelParameters {
3449 provider: Some(model.provider_id().0.to_string().into()),
3450 model: None,
3451 temperature: Some(0.66),
3452 }],
3453 ..AgentSettings::get_global(cx).clone()
3454 },
3455 cx,
3456 );
3457 });
3458
3459 let request = thread.update(cx, |thread, cx| {
3460 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3461 });
3462 assert_eq!(request.temperature, Some(0.66));
3463
3464 // Same model name, different provider
3465 cx.update(|cx| {
3466 AgentSettings::override_global(
3467 AgentSettings {
3468 model_parameters: vec![LanguageModelParameters {
3469 provider: Some("anthropic".into()),
3470 model: Some(model.id().0.clone()),
3471 temperature: Some(0.66),
3472 }],
3473 ..AgentSettings::get_global(cx).clone()
3474 },
3475 cx,
3476 );
3477 });
3478
3479 let request = thread.update(cx, |thread, cx| {
3480 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3481 });
3482 assert_eq!(request.temperature, None);
3483 }
3484
    // Walks the summary through its state machine:
    // Pending -> Generating (after first exchange) -> Ready, and verifies that
    // manual `set_summary` is rejected in the first two states but accepted in
    // Ready (with "" falling back to the default summary text).
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        // Stream the assistant's reply and close that completion.
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set (chunks concatenated)
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3569
    // After summarization fails (ThreadSummary::Error), a summary can still be
    // set manually and the state moves to Ready.
    #[gpui::test]
    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drives the thread into the ThreadSummary::Error state.
        test_summarize_error(&model, &thread, cx);

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });
    }
3591
    // After summarization fails, further messages do NOT retry summarization
    // automatically; an explicit `summarize` call does, and a successful
    // retry moves the state to Ready.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drives the thread into the ThreadSummary::Error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a summary for the retry and close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3642
    // Exercises `resolve_tool_name_conflicts`: unique names pass through;
    // duplicates from context servers get prefixed with the server id; and
    // names are truncated to the 64-character tool-name limit.
    #[gpui::test]
    fn test_resolve_tool_name_conflicts() {
        use assistant_tool::{Tool, ToolSource};

        // No conflicts: all names pass through unchanged.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
            ],
            vec!["tool1", "tool2", "tool3"],
        );

        // Same name from two context servers: both get the server-id prefix.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // Native tool keeps its name; only the context-server duplicates are prefixed.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // Test that tool with very long name is always truncated
        // (expected name is exactly 64 characters)
        assert_resolve_tool_name_conflicts(
            vec![TestTool::new(
                "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
                ToolSource::Native,
            )],
            vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
        );

        // Test deduplication of tools with very long names, in this case the mcp server name should be truncated
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
                TestTool::new(
                    "tool-with-very-very-very-long-name",
                    ToolSource::ContextServer {
                        id: "mcp-with-very-very-very-long-name".into(),
                    },
                ),
            ],
            vec![
                "tool-with-very-very-very-long-name",
                "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
            ],
        );

        // Runs `resolve_tool_name_conflicts` over `tools` and asserts the
        // resolved names equal `expected`, element by element.
        fn assert_resolve_tool_name_conflicts(
            tools: Vec<TestTool>,
            expected: Vec<impl Into<String>>,
        ) {
            let tools: Vec<Arc<dyn Tool>> = tools
                .into_iter()
                .map(|t| Arc::new(t) as Arc<dyn Tool>)
                .collect();
            let tools = resolve_tool_name_conflicts(&tools);
            assert_eq!(tools.len(), expected.len());
            for (i, expected_name) in expected.into_iter().enumerate() {
                let expected_name = expected_name.into();
                let actual_name = &tools[i].0;
                assert_eq!(
                    actual_name, &expected_name,
                    "Expected '{}' got '{}' at index {}",
                    expected_name, actual_name, i
                );
            }
        }

        // Minimal `Tool` implementation carrying only a name and a source —
        // just enough for name-conflict resolution.
        struct TestTool {
            name: String,
            source: ToolSource,
        }

        impl TestTool {
            fn new(name: impl Into<String>, source: ToolSource) -> Self {
                Self {
                    name: name.into(),
                    source,
                }
            }
        }

        impl Tool for TestTool {
            fn name(&self) -> String {
                self.name.clone()
            }

            fn icon(&self) -> IconName {
                IconName::Ai
            }

            fn may_perform_edits(&self) -> bool {
                false
            }

            fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
                true
            }

            fn source(&self) -> ToolSource {
                self.source.clone()
            }

            fn description(&self) -> String {
                "Test tool".to_string()
            }

            fn ui_text(&self, _input: &serde_json::Value) -> String {
                "Test tool".to_string()
            }

            // Never meant to be run in this test; always fails.
            fn run(
                self: Arc<Self>,
                _input: serde_json::Value,
                _request: Arc<LanguageModelRequest>,
                _project: Entity<Project>,
                _action_log: Entity<ActionLog>,
                _model: Arc<dyn LanguageModel>,
                _window: Option<AnyWindowHandle>,
                _cx: &mut App,
            ) -> assistant_tool::ToolResult {
                assistant_tool::ToolResult {
                    output: Task::ready(Err(anyhow::anyhow!("No content"))),
                    card: None,
                }
            }
        }
    }
3784
    // Shared helper: sends a message, completes the assistant response, then
    // ends the summary completion stream without producing any text, leaving
    // the thread summary in the `Error` state with the default text.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // The completed exchange kicks off summary generation.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending (stream closed with no content)
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3819
    // Streams a canned "Assistant response" through the fake model's most
    // recent pending completion and closes its stream, parking the executor
    // before and after so the thread observes the events.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3826
    // Registers the global settings and subsystems every test in this module
    // needs (settings store, language, project, agent, prompts, thread store,
    // workspace, language models, theme, tool registry).
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store must exist before the `init`/`register`
            // calls below read or register their settings.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3842
    // Helper to create a test project with test files.
    // `files` is a serde_json tree mapping file names to contents, mounted on
    // a fake in-memory filesystem under "/test".
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }
3852
    // Builds the standard test fixture: a workspace window, a thread store
    // with a fresh thread, an empty context store, and a fake language model
    // registered as both the default and the thread-summary model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        // Register the fake model so sends and summarization both route to it.
        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }
3907
3908 async fn add_file_to_context(
3909 project: &Entity<Project>,
3910 context_store: &Entity<ContextStore>,
3911 path: &str,
3912 cx: &mut TestAppContext,
3913 ) -> Result<Entity<language::Buffer>> {
3914 let buffer_path = project
3915 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3916 .unwrap();
3917
3918 let buffer = project
3919 .update(cx, |project, cx| {
3920 project.open_buffer(buffer_path.clone(), cx)
3921 })
3922 .await
3923 .unwrap();
3924
3925 context_store.update(cx, |context_store, cx| {
3926 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3927 });
3928
3929 Ok(buffer)
3930 }
3931}