1use crate::{
2 agent_profile::AgentProfile,
3 context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
4 thread_store::{
5 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
6 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
7 ThreadStore,
8 },
9 tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
10};
11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
12use anyhow::{Result, anyhow};
13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
14use chrono::{DateTime, Utc};
15use client::{ModelRequestUsage, RequestUsage};
16use collections::{HashMap, HashSet};
17use feature_flags::{self, FeatureFlagAppExt};
18use futures::{FutureExt, StreamExt as _, future::Shared};
19use git::repository::DiffType;
20use gpui::{
21 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
22 WeakEntity, Window,
23};
24use language_model::{
25 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
27 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
28 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
29 ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
30 TokenUsage,
31};
32use postage::stream::Stream as _;
33use project::{
34 Project,
35 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
36};
37use prompt_store::{ModelContext, PromptBuilder};
38use proto::Plan;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use settings::Settings;
42use std::{io::Write, ops::Range, sync::Arc, time::Instant};
43use thiserror::Error;
44use util::{ResultExt as _, post_inc};
45use uuid::Uuid;
46use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
47
48#[derive(
49 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
50)]
51pub struct ThreadId(Arc<str>);
52
53impl ThreadId {
54 pub fn new() -> Self {
55 Self(Uuid::new_v4().to_string().into())
56 }
57}
58
59impl std::fmt::Display for ThreadId {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 write!(f, "{}", self.0)
62 }
63}
64
65impl From<&str> for ThreadId {
66 fn from(value: &str) -> Self {
67 Self(value.into())
68 }
69}
70
71/// The ID of the user prompt that initiated a request.
72///
73/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
74#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
75pub struct PromptId(Arc<str>);
76
77impl PromptId {
78 pub fn new() -> Self {
79 Self(Uuid::new_v4().to_string().into())
80 }
81}
82
83impl std::fmt::Display for PromptId {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 write!(f, "{}", self.0)
86 }
87}
88
89#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
90pub struct MessageId(pub(crate) usize);
91
92impl MessageId {
93 fn post_inc(&mut self) -> Self {
94 Self(post_inc(&mut self.0))
95 }
96
97 pub fn as_usize(&self) -> usize {
98 self.0
99 }
100}
101
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Offset range of the crease within the message text
    /// (presumably byte offsets — confirm against the editor call sites).
    pub range: Range<usize>,
    /// Icon shown on the collapsed crease.
    pub icon_path: SharedString,
    /// Label shown on the collapsed crease.
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
111
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context that was loaded for and attached to this message.
    pub loaded_context: LoadedContext,
    /// Creases used to resurrect context mentions when editing this message.
    pub creases: Vec<MessageCrease>,
    /// True for machine-generated messages (e.g. the automatic
    /// "Continue where you left off" prompt) — presumably excluded from
    /// display; see `is_turn_end`.
    pub is_hidden: bool,
}
122
123impl Message {
124 /// Returns whether the message contains any meaningful text that should be displayed
125 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
126 pub fn should_display_content(&self) -> bool {
127 self.segments.iter().all(|segment| segment.should_display())
128 }
129
130 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
131 if let Some(MessageSegment::Thinking {
132 text: segment,
133 signature: current_signature,
134 }) = self.segments.last_mut()
135 {
136 if let Some(signature) = signature {
137 *current_signature = Some(signature);
138 }
139 segment.push_str(text);
140 } else {
141 self.segments.push(MessageSegment::Thinking {
142 text: text.to_string(),
143 signature,
144 });
145 }
146 }
147
148 pub fn push_redacted_thinking(&mut self, data: String) {
149 self.segments.push(MessageSegment::RedactedThinking(data));
150 }
151
152 pub fn push_text(&mut self, text: &str) {
153 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
154 segment.push_str(text);
155 } else {
156 self.segments.push(MessageSegment::Text(text.to_string()));
157 }
158 }
159
160 pub fn to_string(&self) -> String {
161 let mut result = String::new();
162
163 if !self.loaded_context.text.is_empty() {
164 result.push_str(&self.loaded_context.text);
165 }
166
167 for segment in &self.segments {
168 match segment {
169 MessageSegment::Text(text) => result.push_str(text),
170 MessageSegment::Thinking { text, .. } => {
171 result.push_str("<think>\n");
172 result.push_str(text);
173 result.push_str("\n</think>");
174 }
175 MessageSegment::RedactedThinking(_) => {}
176 }
177 }
178
179 result
180 }
181}
182
/// A single piece of message content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(String),
}

impl MessageSegment {
    /// Whether this segment carries content worth rendering.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking is opaque data and is never displayed.
    pub fn should_display(&self) -> bool {
        match self {
            // Previous logic was inverted (`text.is_empty()`), which made
            // empty segments "displayable" and real content hidden.
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
202
/// Point-in-time capture of the project's worktrees, taken when a thread starts.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved edits at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git state when available.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when no git state could be captured for the worktree.
    pub git_state: Option<GitState>,
}

/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Diff text at snapshot time, when one could be computed.
    pub diff: Option<String>,
}
223
/// Associates a git checkpoint with the message it was captured for.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
229
/// User feedback rating on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
235
236pub enum LastRestoreCheckpoint {
237 Pending {
238 message_id: MessageId,
239 },
240 Error {
241 message_id: MessageId,
242 error: String,
243 },
244}
245
246impl LastRestoreCheckpoint {
247 pub fn message_id(&self) -> MessageId {
248 match self {
249 LastRestoreCheckpoint::Pending { message_id } => *message_id,
250 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
251 }
252 }
253}
254
255#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
256pub enum DetailedSummaryState {
257 #[default]
258 NotGenerated,
259 Generating {
260 message_id: MessageId,
261 },
262 Generated {
263 text: SharedString,
264 message_id: MessageId,
265 },
266}
267
268impl DetailedSummaryState {
269 fn text(&self) -> Option<SharedString> {
270 if let Self::Generated { text, .. } = self {
271 Some(text.clone())
272 } else {
273 None
274 }
275 }
276}
277
/// Running token usage for a thread, measured against the model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: u64,
    /// Context-window size; 0 when no model is selected.
    pub max: u64,
}

impl TotalTokenUsage {
    /// Classifies current usage against the context window.
    ///
    /// In debug builds the warning threshold can be overridden via the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable; an unset or
    /// malformed value falls back to the default instead of panicking
    /// (the previous code allocated eagerly and `unwrap`ped the parse).
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy with `tokens` added to the running total
    /// (saturating, so a corrupt count cannot cause a debug-build overflow panic).
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total.saturating_add(tokens),
            max: self.max,
        }
    }
}

/// How close the thread is to exhausting the model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
322
/// Lifecycle of a completion request before it starts streaming.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    // `position` is presumably 0- or 1-based as reported upstream — confirm at the producer.
    Queued { position: usize },
    Started,
}
329
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Last time the thread was mutated (see `touch_updated_at`).
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    /// Watch channel broadcasting detailed-summary progress to observers.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    /// Messages in ascending `MessageId` order — `message()` binary-searches on this.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// In-flight completion requests; non-empty implies `is_generating()`.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured at the latest user message; finalized once the turn ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured asynchronously at thread creation.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Timestamp of the most recent streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Presumably a budget of agentic turns before auto-continue stops; the
    // decrement site is not visible in this chunk — confirm elsewhere.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
370
371#[derive(Clone, Debug, PartialEq, Eq)]
372pub enum ThreadSummary {
373 Pending,
374 Generating,
375 Ready(SharedString),
376 Error,
377}
378
379impl ThreadSummary {
380 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
381
382 pub fn or_default(&self) -> SharedString {
383 self.unwrap_or(Self::DEFAULT)
384 }
385
386 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
387 self.ready().unwrap_or_else(|| message.into())
388 }
389
390 pub fn ready(&self) -> Option<SharedString> {
391 match self {
392 ThreadSummary::Ready(summary) => Some(summary.clone()),
393 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
394 }
395 }
396}
397
/// Recorded when a request fails because the conversation no longer fits in
/// the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
405
406impl Thread {
    /// Creates an empty thread, taking defaults from the global settings and
    /// the language-model registry.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        // May be None when no default model is configured yet.
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project state asynchronously so construction stays
                // cheap; the shared future lets multiple consumers await it.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
461
    /// Rebuilds a thread from its serialized form.
    ///
    /// Live context handles, checkpoints, and feedback are not persisted, so
    /// they start out empty. `window` is `None` in headless mode.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Message IDs are assumed ascending, so the next ID follows the last
        // message's — TODO confirm serialized order is always ascending.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model recorded with the thread; fall back to the global default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization;
                    // structured contexts and images are not persisted.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            // Live context handles are not persisted (see MessageCrease).
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
583
584 pub fn set_request_callback(
585 &mut self,
586 callback: impl 'static
587 + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
588 ) {
589 self.request_callback = Some(Box::new(callback));
590 }
591
592 pub fn id(&self) -> &ThreadId {
593 &self.id
594 }
595
596 pub fn profile(&self) -> &AgentProfile {
597 &self.profile
598 }
599
600 pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
601 if &id != self.profile.id() {
602 self.profile = AgentProfile::new(id, self.tools.clone());
603 cx.emit(ThreadEvent::ProfileChanged);
604 }
605 }
606
607 pub fn is_empty(&self) -> bool {
608 self.messages.is_empty()
609 }
610
611 pub fn updated_at(&self) -> DateTime<Utc> {
612 self.updated_at
613 }
614
615 pub fn touch_updated_at(&mut self) {
616 self.updated_at = Utc::now();
617 }
618
619 pub fn advance_prompt_id(&mut self) {
620 self.last_prompt_id = PromptId::new();
621 }
622
623 pub fn project_context(&self) -> SharedProjectContext {
624 self.project_context.clone()
625 }
626
627 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
628 if self.configured_model.is_none() {
629 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
630 }
631 self.configured_model.clone()
632 }
633
634 pub fn configured_model(&self) -> Option<ConfiguredModel> {
635 self.configured_model.clone()
636 }
637
638 pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
639 self.configured_model = model;
640 cx.notify();
641 }
642
643 pub fn summary(&self) -> &ThreadSummary {
644 &self.summary
645 }
646
647 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
648 let current_summary = match &self.summary {
649 ThreadSummary::Pending | ThreadSummary::Generating => return,
650 ThreadSummary::Ready(summary) => summary,
651 ThreadSummary::Error => &ThreadSummary::DEFAULT,
652 };
653
654 let mut new_summary = new_summary.into();
655
656 if new_summary.is_empty() {
657 new_summary = ThreadSummary::DEFAULT;
658 }
659
660 if current_summary != &new_summary {
661 self.summary = ThreadSummary::Ready(new_summary);
662 cx.emit(ThreadEvent::SummaryChanged);
663 }
664 }
665
666 pub fn completion_mode(&self) -> CompletionMode {
667 self.completion_mode
668 }
669
670 pub fn set_completion_mode(&mut self, mode: CompletionMode) {
671 self.completion_mode = mode;
672 }
673
674 pub fn message(&self, id: MessageId) -> Option<&Message> {
675 let index = self
676 .messages
677 .binary_search_by(|message| message.id.cmp(&id))
678 .ok()?;
679
680 self.messages.get(index)
681 }
682
683 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
684 self.messages.iter()
685 }
686
687 pub fn is_generating(&self) -> bool {
688 !self.pending_completions.is_empty() || !self.all_tools_finished()
689 }
690
691 /// Indicates whether streaming of language model events is stale.
692 /// When `is_generating()` is false, this method returns `None`.
693 pub fn is_generation_stale(&self) -> Option<bool> {
694 const STALE_THRESHOLD: u128 = 250;
695
696 self.last_received_chunk_at
697 .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
698 }
699
700 fn received_chunk(&mut self) {
701 self.last_received_chunk_at = Some(Instant::now());
702 }
703
704 pub fn queue_state(&self) -> Option<QueueState> {
705 self.pending_completions
706 .first()
707 .map(|pending_completion| pending_completion.queue_state)
708 }
709
710 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
711 &self.tools
712 }
713
714 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
715 self.tool_use
716 .pending_tool_uses()
717 .into_iter()
718 .find(|tool_use| &tool_use.id == id)
719 }
720
721 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
722 self.tool_use
723 .pending_tool_uses()
724 .into_iter()
725 .filter(|tool_use| tool_use.status.needs_confirmation())
726 }
727
728 pub fn has_pending_tool_uses(&self) -> bool {
729 !self.tool_use.pending_tool_uses().is_empty()
730 }
731
732 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
733 self.checkpoints_by_message.get(&id).cloned()
734 }
735
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Progress is tracked via `last_restore_checkpoint` so the UI can show
    /// pending/error states; `pending_checkpoint` is dropped either way.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so the UI can surface why the restore failed.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop messages after the checkpoint's message.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
771
772 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
773 let pending_checkpoint = if self.is_generating() {
774 return;
775 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
776 checkpoint
777 } else {
778 return;
779 };
780
781 self.finalize_checkpoint(pending_checkpoint, cx);
782 }
783
    /// Compares `pending_checkpoint` against the repository's current state
    /// and records it only when the project actually changed during the turn.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // A failed comparison is treated as "changed" so the
                    // checkpoint is kept rather than silently dropped.
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If the final checkpoint can't be taken at all, keep the pending
            // one rather than lose it.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
818
819 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
820 self.checkpoints_by_message
821 .insert(checkpoint.message_id, checkpoint);
822 cx.emit(ThreadEvent::CheckpointChanged);
823 cx.notify();
824 }
825
826 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
827 self.last_restore_checkpoint.as_ref()
828 }
829
830 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
831 let Some(message_ix) = self
832 .messages
833 .iter()
834 .rposition(|message| message.id == message_id)
835 else {
836 return;
837 };
838 for deleted_message in self.messages.drain(message_ix..) {
839 self.checkpoints_by_message.remove(&deleted_message.id);
840 }
841 cx.notify();
842 }
843
844 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
845 self.messages
846 .iter()
847 .find(|message| message.id == id)
848 .into_iter()
849 .flat_map(|message| message.loaded_context.contexts.iter())
850 }
851
852 pub fn is_turn_end(&self, ix: usize) -> bool {
853 if self.messages.is_empty() {
854 return false;
855 }
856
857 if !self.is_generating() && ix == self.messages.len() - 1 {
858 return true;
859 }
860
861 let Some(message) = self.messages.get(ix) else {
862 return false;
863 };
864
865 if message.role != Role::Assistant {
866 return false;
867 }
868
869 self.messages
870 .get(ix + 1)
871 .and_then(|message| {
872 self.message(message.id)
873 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
874 })
875 .unwrap_or(false)
876 }
877
878 pub fn tool_use_limit_reached(&self) -> bool {
879 self.tool_use_limit_reached
880 }
881
882 /// Returns whether all of the tool uses have finished running.
883 pub fn all_tools_finished(&self) -> bool {
884 // If the only pending tool uses left are the ones with errors, then
885 // that means that we've finished running all of the pending tools.
886 self.tool_use
887 .pending_tool_uses()
888 .iter()
889 .all(|pending_tool_use| pending_tool_use.status.is_error())
890 }
891
892 /// Returns whether any pending tool uses may perform edits
893 pub fn has_pending_edit_tool_uses(&self) -> bool {
894 self.tool_use
895 .pending_tool_uses()
896 .iter()
897 .filter(|pending_tool_use| !pending_tool_use.status.is_error())
898 .any(|pending_tool_use| pending_tool_use.may_perform_edits)
899 }
900
901 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
902 self.tool_use.tool_uses_for_message(id, cx)
903 }
904
905 pub fn tool_results_for_message(
906 &self,
907 assistant_message_id: MessageId,
908 ) -> Vec<&LanguageModelToolResult> {
909 self.tool_use.tool_results_for_message(assistant_message_id)
910 }
911
912 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
913 self.tool_use.tool_result(id)
914 }
915
916 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
917 match &self.tool_use.tool_result(id)?.content {
918 LanguageModelToolResultContent::Text(text) => Some(text),
919 LanguageModelToolResultContent::Image(_) => {
920 // TODO: We should display image
921 None
922 }
923 }
924 }
925
926 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
927 self.tool_use.tool_result_card(id).cloned()
928 }
929
930 /// Return tools that are both enabled and supported by the model
931 pub fn available_tools(
932 &self,
933 cx: &App,
934 model: Arc<dyn LanguageModel>,
935 ) -> Vec<LanguageModelRequestTool> {
936 if model.supports_tools() {
937 resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
938 .into_iter()
939 .filter_map(|(name, tool)| {
940 // Skip tools that cannot be supported
941 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
942 Some(LanguageModelRequestTool {
943 name,
944 description: tool.description(),
945 input_schema,
946 })
947 })
948 .collect()
949 } else {
950 Vec::default()
951 }
952 }
953
954 pub fn insert_user_message(
955 &mut self,
956 text: impl Into<String>,
957 loaded_context: ContextLoadResult,
958 git_checkpoint: Option<GitStoreCheckpoint>,
959 creases: Vec<MessageCrease>,
960 cx: &mut Context<Self>,
961 ) -> MessageId {
962 if !loaded_context.referenced_buffers.is_empty() {
963 self.action_log.update(cx, |log, cx| {
964 for buffer in loaded_context.referenced_buffers {
965 log.buffer_read(buffer, cx);
966 }
967 });
968 }
969
970 let message_id = self.insert_message(
971 Role::User,
972 vec![MessageSegment::Text(text.into())],
973 loaded_context.loaded_context,
974 creases,
975 false,
976 cx,
977 );
978
979 if let Some(git_checkpoint) = git_checkpoint {
980 self.pending_checkpoint = Some(ThreadCheckpoint {
981 message_id,
982 git_checkpoint,
983 });
984 }
985
986 self.auto_capture_telemetry(cx);
987
988 message_id
989 }
990
991 pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
992 let id = self.insert_message(
993 Role::User,
994 vec![MessageSegment::Text("Continue where you left off".into())],
995 LoadedContext::default(),
996 vec![],
997 true,
998 cx,
999 );
1000 self.pending_checkpoint = None;
1001
1002 id
1003 }
1004
1005 pub fn insert_assistant_message(
1006 &mut self,
1007 segments: Vec<MessageSegment>,
1008 cx: &mut Context<Self>,
1009 ) -> MessageId {
1010 self.insert_message(
1011 Role::Assistant,
1012 segments,
1013 LoadedContext::default(),
1014 Vec::new(),
1015 false,
1016 cx,
1017 )
1018 }
1019
1020 pub fn insert_message(
1021 &mut self,
1022 role: Role,
1023 segments: Vec<MessageSegment>,
1024 loaded_context: LoadedContext,
1025 creases: Vec<MessageCrease>,
1026 is_hidden: bool,
1027 cx: &mut Context<Self>,
1028 ) -> MessageId {
1029 let id = self.next_message_id.post_inc();
1030 self.messages.push(Message {
1031 id,
1032 role,
1033 segments,
1034 loaded_context,
1035 creases,
1036 is_hidden,
1037 });
1038 self.touch_updated_at();
1039 cx.emit(ThreadEvent::MessageAdded(id));
1040 id
1041 }
1042
1043 pub fn edit_message(
1044 &mut self,
1045 id: MessageId,
1046 new_role: Role,
1047 new_segments: Vec<MessageSegment>,
1048 creases: Vec<MessageCrease>,
1049 loaded_context: Option<LoadedContext>,
1050 checkpoint: Option<GitStoreCheckpoint>,
1051 cx: &mut Context<Self>,
1052 ) -> bool {
1053 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1054 return false;
1055 };
1056 message.role = new_role;
1057 message.segments = new_segments;
1058 message.creases = creases;
1059 if let Some(context) = loaded_context {
1060 message.loaded_context = context;
1061 }
1062 if let Some(git_checkpoint) = checkpoint {
1063 self.checkpoints_by_message.insert(
1064 id,
1065 ThreadCheckpoint {
1066 message_id: id,
1067 git_checkpoint,
1068 },
1069 );
1070 }
1071 self.touch_updated_at();
1072 cx.emit(ThreadEvent::MessageEdited(id));
1073 true
1074 }
1075
1076 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1077 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1078 return false;
1079 };
1080 self.messages.remove(index);
1081 self.touch_updated_at();
1082 cx.emit(ThreadEvent::MessageDeleted(id));
1083 true
1084 }
1085
1086 /// Returns the representation of this [`Thread`] in a textual form.
1087 ///
1088 /// This is the representation we use when attaching a thread as context to another thread.
1089 pub fn text(&self) -> String {
1090 let mut text = String::new();
1091
1092 for message in &self.messages {
1093 text.push_str(match message.role {
1094 language_model::Role::User => "User:",
1095 language_model::Role::Assistant => "Agent:",
1096 language_model::Role::System => "System:",
1097 });
1098 text.push('\n');
1099
1100 for segment in &message.segments {
1101 match segment {
1102 MessageSegment::Text(content) => text.push_str(content),
1103 MessageSegment::Thinking { text: content, .. } => {
1104 text.push_str(&format!("<think>{}</think>", content))
1105 }
1106 MessageSegment::RedactedThinking(_) => {}
1107 }
1108 }
1109 text.push('\n');
1110 }
1111
1112 text
1113 }
1114
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Returns a task because the initial project snapshot is produced
    /// asynchronously; the rest of the state is read synchronously off the
    /// entity once the snapshot resolves.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot before reading the thread state, so the
            // serialized form is assembled in one consistent read.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        // In-memory segments map 1:1 onto their serialized forms.
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results live in `self.tool_use`, keyed by
                        // message id, rather than on the message itself.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                // Only provider/model ids are persisted, not the live handles.
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1200
    /// Returns the number of agentic turns this thread may still take;
    /// `send_to_model` becomes a no-op once this reaches zero.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1204
    /// Resets the turn budget consumed by [`Thread::send_to_model`].
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1208
1209 pub fn send_to_model(
1210 &mut self,
1211 model: Arc<dyn LanguageModel>,
1212 intent: CompletionIntent,
1213 window: Option<AnyWindowHandle>,
1214 cx: &mut Context<Self>,
1215 ) {
1216 if self.remaining_turns == 0 {
1217 return;
1218 }
1219
1220 self.remaining_turns -= 1;
1221
1222 let request = self.to_completion_request(model.clone(), intent, cx);
1223
1224 self.stream_completion(request, model, window, cx);
1225 }
1226
1227 pub fn used_tools_since_last_user_message(&self) -> bool {
1228 for message in self.messages.iter().rev() {
1229 if self.tool_use.message_has_tool_results(message.id) {
1230 return true;
1231 } else if message.role == Role::User {
1232 return false;
1233 }
1234 }
1235
1236 false
1237 }
1238
    /// Builds the [`LanguageModelRequest`] for the next conversational turn.
    ///
    /// The request contains: the system prompt (when the shared project
    /// context has loaded), every message with its attached context, completed
    /// tool uses paired with their results, the currently-available tools, and
    /// the completion mode. Emits [`ThreadEvent::ShowError`] if the system
    /// prompt cannot be generated.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt requires the shared project context; if it isn't
        // ready yet we surface an error instead of silently sending a
        // prompt-less request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message eligible for prompt caching; only
        // messages whose tool uses have all completed qualify (see the link at
        // the bottom of the loop).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context precedes the message's own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool results go back to the model as a separate user
            // message immediately following the assistant's tool uses; pending
            // tool uses are skipped (and disqualify this message from caching).
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        // Models without max-mode support always get normal mode.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1396
1397 fn to_summarize_request(
1398 &self,
1399 model: &Arc<dyn LanguageModel>,
1400 intent: CompletionIntent,
1401 added_user_message: String,
1402 cx: &App,
1403 ) -> LanguageModelRequest {
1404 let mut request = LanguageModelRequest {
1405 thread_id: None,
1406 prompt_id: None,
1407 intent: Some(intent),
1408 mode: None,
1409 messages: vec![],
1410 tools: Vec::new(),
1411 tool_choice: None,
1412 stop: Vec::new(),
1413 temperature: AgentSettings::temperature_for_model(model, cx),
1414 };
1415
1416 for message in &self.messages {
1417 let mut request_message = LanguageModelRequestMessage {
1418 role: message.role,
1419 content: Vec::new(),
1420 cache: false,
1421 };
1422
1423 for segment in &message.segments {
1424 match segment {
1425 MessageSegment::Text(text) => request_message
1426 .content
1427 .push(MessageContent::Text(text.clone())),
1428 MessageSegment::Thinking { .. } => {}
1429 MessageSegment::RedactedThinking(_) => {}
1430 }
1431 }
1432
1433 if request_message.content.is_empty() {
1434 continue;
1435 }
1436
1437 request.messages.push(request_message);
1438 }
1439
1440 request.messages.push(LanguageModelRequestMessage {
1441 role: Role::User,
1442 content: vec![MessageContent::Text(added_user_message)],
1443 cache: false,
1444 });
1445
1446 request
1447 }
1448
    /// Spawns a task that streams a completion for `request` from `model` and
    /// applies every streamed event to this thread.
    ///
    /// The task is registered in `self.pending_completions` under a fresh id so
    /// its queue state can be tracked and it can be cancelled. Text/thinking
    /// chunks are appended to the trailing assistant message (inserting one
    /// when needed), tool uses are registered with `self.tool_use`, and token
    /// usage is accumulated. Once the stream ends, the stop reason decides the
    /// follow-up: run pending tools, clear the agent location, or — on refusal
    /// — delete the refused turn. Errors are mapped onto
    /// [`ThreadEvent::ShowError`] variants and the completion is cancelled.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // When a request callback is installed, capture the request plus every
        // response event so they can be replayed to it after the stream ends.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before streaming so the per-request delta can be
            // reported to telemetry at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            Err(error) => {
                                // Map provider-level errors onto known error
                                // types so the handler after the stream can
                                // react to them uniformly.
                                match error {
                                    LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
                                        anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
                                    }
                                    LanguageModelCompletionError::Overloaded => {
                                        anyhow::bail!(LanguageModelKnownError::Overloaded);
                                    }
                                    LanguageModelCompletionError::ApiInternalServerError => {
                                        anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
                                    }
                                    LanguageModelCompletionError::PromptTooLarge { tokens } => {
                                        let tokens = tokens.unwrap_or_else(|| {
                                            // We didn't get an exact token count from the API, so fall back on our estimate.
                                            thread.total_token_usage()
                                                .map(|usage| usage.total)
                                                .unwrap_or(0)
                                                // We know the context window was exceeded in practice, so if our estimate was
                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
                                                .max(model.max_token_count().saturating_add(1))
                                        });

                                        anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
                                    }
                                    LanguageModelCompletionError::ApiReadResponseError(io_error) => {
                                        anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
                                    }
                                    LanguageModelCompletionError::UnknownResponseFormat(error) => {
                                        anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
                                    }
                                    LanguageModelCompletionError::HttpResponseError { status, ref body } => {
                                        if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
                                            anyhow::bail!(known_error);
                                        } else {
                                            return Err(error.into());
                                        }
                                    }
                                    LanguageModelCompletionError::DeserializeResponse(error) => {
                                        anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
                                    }
                                    LanguageModelCompletionError::BadInputJson {
                                        id,
                                        tool_name,
                                        raw_input: invalid_input_json,
                                        json_parse_error,
                                    } => {
                                        // Invalid tool input is reported back to
                                        // the model as a tool error rather than
                                        // aborting the stream.
                                        thread.receive_invalid_tool_json(
                                            id,
                                            tool_name,
                                            invalid_input_json,
                                            json_parse_error,
                                            window,
                                            cx,
                                        );
                                        return Ok(());
                                    }
                                    // These are all errors we can't automatically attempt to recover from (e.g. by retrying)
                                    err @ LanguageModelCompletionError::BadRequestFormat |
                                    err @ LanguageModelCompletionError::AuthenticationError |
                                    err @ LanguageModelCompletionError::PermissionError |
                                    err @ LanguageModelCompletionError::ApiEndpointNotFound |
                                    err @ LanguageModelCompletionError::SerializeRequest(_) |
                                    err @ LanguageModelCompletionError::BuildRequestBody(_) |
                                    err @ LanguageModelCompletionError::HttpSend(_) => {
                                        anyhow::bail!(err);
                                    }
                                    LanguageModelCompletionError::Other(error) => {
                                        return Err(error);
                                    }
                                }
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // add only the delta since the last update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::RedactedThinking {
                                data
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_redacted_thinking(data);
                                    } else {
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::RedactedThinking(data)],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to the current assistant
                                // message, creating an empty one if the model
                                // called a tool before emitting any text.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Partial input is forwarded so the UI can
                                // stream it; completed input needs no event.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                // Status updates describe this request's place
                                // in the server-side queue and usage limits.
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            thread.update_model_request_usage(amount as u32, limit, cx);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so event processing doesn't starve
                    // the executor.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                            StopReason::Refusal => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });

                                // Remove the turn that was refused.
                                //
                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                {
                                    let mut messages_to_remove = Vec::new();

                                    // Walk backwards collecting messages until
                                    // reaching the user message that began the
                                    // refused turn.
                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
                                        messages_to_remove.push(message.id);

                                        if message.role == Role::User {
                                            if ix == 0 {
                                                break;
                                            }

                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                if prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }
                                    }

                                    for message_id in messages_to_remove {
                                        thread.delete_message(message_id, cx);
                                    }
                                }

                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Language model refusal".into(),
                                    message: "Model refused to generate content for safety reasons.".into(),
                                }));
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Shows the full error chain to the user.
                            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                    LanguageModelKnownError::RateLimitExceeded { .. } => {
                                        // In the future we will report the error to the user, wait retry_after, and then retry.
                                        emit_generic_error(error, cx);
                                    }
                                    LanguageModelKnownError::Overloaded => {
                                        // In the future we will wait and then retry, up to N times.
                                        emit_generic_error(error, cx);
                                    }
                                    LanguageModelKnownError::ApiInternalServerError => {
                                        // In the future we will retry the request, but only once.
                                        emit_generic_error(error, cx);
                                    }
                                    LanguageModelKnownError::ReadResponseError(_) |
                                    LanguageModelKnownError::DeserializeResponse(_) |
                                    LanguageModelKnownError::UnknownResponseFormat(_) => {
                                        // In the future we will attempt to re-roll response, but only once
                                        emit_generic_error(error, cx);
                                    }
                                }
                            } else {
                                emit_generic_error(error, cx);
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }

                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    // Replay the captured request/response pairs to the
                    // installed request callback, if any.
                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report the tokens consumed by this request (delta from
                    // the pre-stream snapshot).
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1919
1920 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1921 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1922 println!("No thread summary model");
1923 return;
1924 };
1925
1926 if !model.provider.is_authenticated(cx) {
1927 return;
1928 }
1929
1930 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
1931
1932 let request = self.to_summarize_request(
1933 &model.model,
1934 CompletionIntent::ThreadSummarization,
1935 added_user_message.into(),
1936 cx,
1937 );
1938
1939 self.summary = ThreadSummary::Generating;
1940
1941 self.pending_summary = cx.spawn(async move |this, cx| {
1942 let result = async {
1943 let mut messages = model.model.stream_completion(request, &cx).await?;
1944
1945 let mut new_summary = String::new();
1946 while let Some(event) = messages.next().await {
1947 let Ok(event) = event else {
1948 continue;
1949 };
1950 let text = match event {
1951 LanguageModelCompletionEvent::Text(text) => text,
1952 LanguageModelCompletionEvent::StatusUpdate(
1953 CompletionRequestStatus::UsageUpdated { amount, limit },
1954 ) => {
1955 this.update(cx, |thread, cx| {
1956 thread.update_model_request_usage(amount as u32, limit, cx);
1957 })?;
1958 continue;
1959 }
1960 _ => continue,
1961 };
1962
1963 let mut lines = text.lines();
1964 new_summary.extend(lines.next());
1965
1966 // Stop if the LLM generated multiple lines.
1967 if lines.next().is_some() {
1968 break;
1969 }
1970 }
1971
1972 anyhow::Ok(new_summary)
1973 }
1974 .await;
1975
1976 this.update(cx, |this, cx| {
1977 match result {
1978 Ok(new_summary) => {
1979 if new_summary.is_empty() {
1980 this.summary = ThreadSummary::Error;
1981 } else {
1982 this.summary = ThreadSummary::Ready(new_summary.into());
1983 }
1984 }
1985 Err(err) => {
1986 this.summary = ThreadSummary::Error;
1987 log::error!("Failed to generate thread summary: {}", err);
1988 }
1989 }
1990 cx.emit(ThreadEvent::SummaryGenerated);
1991 })
1992 .log_err()?;
1993
1994 Some(())
1995 });
1996 }
1997
    /// Starts generating a detailed summary of this thread unless one is
    /// already generating/generated for the latest message.
    ///
    /// The resulting state is broadcast through `detailed_summary_tx`, and the
    /// thread is saved via `thread_store` so the summary can be reused later.
    /// No-op on an empty thread, when no summary model is configured, or when
    /// the provider is not authenticated.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        // A summary is keyed to the message it covers; one for the current
        // last message means there is nothing to do.
        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Streaming failed to start: reset the state so a later call
                // can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate all chunks; failed chunks are logged and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2087
    /// Waits until detailed-summary generation settles, then returns the
    /// summary — or the plain thread text when no summary was generated.
    ///
    /// Returns `None` when the thread entity has been dropped or the state
    /// channel closes while waiting.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating — keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
2105
    /// Returns the most recently generated detailed summary, falling back to
    /// the plain textual representation of the thread when none exists.
    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
        self.detailed_summary_rx
            .borrow()
            .text()
            .unwrap_or_else(|| self.text().into())
    }
2112
    /// Returns whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2119
2120 pub fn use_pending_tools(
2121 &mut self,
2122 window: Option<AnyWindowHandle>,
2123 cx: &mut Context<Self>,
2124 model: Arc<dyn LanguageModel>,
2125 ) -> Vec<PendingToolUse> {
2126 self.auto_capture_telemetry(cx);
2127 let request =
2128 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2129 let pending_tool_uses = self
2130 .tool_use
2131 .pending_tool_uses()
2132 .into_iter()
2133 .filter(|tool_use| tool_use.status.is_idle())
2134 .cloned()
2135 .collect::<Vec<_>>();
2136
2137 for tool_use in pending_tool_uses.iter() {
2138 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
2139 if tool.needs_confirmation(&tool_use.input, cx)
2140 && !AgentSettings::get_global(cx).always_allow_tool_actions
2141 {
2142 self.tool_use.confirm_tool_use(
2143 tool_use.id.clone(),
2144 tool_use.ui_text.clone(),
2145 tool_use.input.clone(),
2146 request.clone(),
2147 tool,
2148 );
2149 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2150 } else {
2151 self.run_tool(
2152 tool_use.id.clone(),
2153 tool_use.ui_text.clone(),
2154 tool_use.input.clone(),
2155 request.clone(),
2156 tool,
2157 model.clone(),
2158 window,
2159 cx,
2160 );
2161 }
2162 } else {
2163 self.handle_hallucinated_tool_use(
2164 tool_use.id.clone(),
2165 tool_use.name.clone(),
2166 window,
2167 cx,
2168 );
2169 }
2170 }
2171
2172 pending_tool_uses
2173 }
2174
2175 pub fn handle_hallucinated_tool_use(
2176 &mut self,
2177 tool_use_id: LanguageModelToolUseId,
2178 hallucinated_tool_name: Arc<str>,
2179 window: Option<AnyWindowHandle>,
2180 cx: &mut Context<Thread>,
2181 ) {
2182 let available_tools = self.profile.enabled_tools(cx);
2183
2184 let tool_list = available_tools
2185 .iter()
2186 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2187 .collect::<Vec<_>>()
2188 .join("\n");
2189
2190 let error_message = format!(
2191 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2192 hallucinated_tool_name, tool_list
2193 );
2194
2195 let pending_tool_use = self.tool_use.insert_tool_output(
2196 tool_use_id.clone(),
2197 hallucinated_tool_name,
2198 Err(anyhow!("Missing tool call: {error_message}")),
2199 self.configured_model.as_ref(),
2200 );
2201
2202 cx.emit(ThreadEvent::MissingToolUse {
2203 tool_use_id: tool_use_id.clone(),
2204 ui_text: error_message.into(),
2205 });
2206
2207 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2208 }
2209
2210 pub fn receive_invalid_tool_json(
2211 &mut self,
2212 tool_use_id: LanguageModelToolUseId,
2213 tool_name: Arc<str>,
2214 invalid_json: Arc<str>,
2215 error: String,
2216 window: Option<AnyWindowHandle>,
2217 cx: &mut Context<Thread>,
2218 ) {
2219 log::error!("The model returned invalid input JSON: {invalid_json}");
2220
2221 let pending_tool_use = self.tool_use.insert_tool_output(
2222 tool_use_id.clone(),
2223 tool_name,
2224 Err(anyhow!("Error parsing input JSON: {error}")),
2225 self.configured_model.as_ref(),
2226 );
2227 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2228 pending_tool_use.ui_text.clone()
2229 } else {
2230 log::error!(
2231 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2232 );
2233 format!("Unknown tool {}", tool_use_id).into()
2234 };
2235
2236 cx.emit(ThreadEvent::InvalidToolInput {
2237 tool_use_id: tool_use_id.clone(),
2238 ui_text,
2239 invalid_input_json: invalid_json,
2240 });
2241
2242 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2243 }
2244
2245 pub fn run_tool(
2246 &mut self,
2247 tool_use_id: LanguageModelToolUseId,
2248 ui_text: impl Into<SharedString>,
2249 input: serde_json::Value,
2250 request: Arc<LanguageModelRequest>,
2251 tool: Arc<dyn Tool>,
2252 model: Arc<dyn LanguageModel>,
2253 window: Option<AnyWindowHandle>,
2254 cx: &mut Context<Thread>,
2255 ) {
2256 let task =
2257 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2258 self.tool_use
2259 .run_pending_tool(tool_use_id, ui_text.into(), task);
2260 }
2261
    /// Starts running `tool` in the background and returns the task that
    /// completes once the tool's output has been recorded on this thread.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // The tool starts running here; the returned result holds the future
        // for its eventual output (and possibly a UI card).
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                // Wait for the tool to finish producing its output.
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool ran; in that
                // case there is nothing to record, so the error is ignored.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2308
2309 fn tool_finished(
2310 &mut self,
2311 tool_use_id: LanguageModelToolUseId,
2312 pending_tool_use: Option<PendingToolUse>,
2313 canceled: bool,
2314 window: Option<AnyWindowHandle>,
2315 cx: &mut Context<Self>,
2316 ) {
2317 if self.all_tools_finished() {
2318 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2319 if !canceled {
2320 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2321 }
2322 self.auto_capture_telemetry(cx);
2323 }
2324 }
2325
2326 cx.emit(ThreadEvent::ToolFinished {
2327 tool_use_id,
2328 pending_tool_use,
2329 });
2330 }
2331
2332 /// Cancels the last pending completion, if there are any pending.
2333 ///
2334 /// Returns whether a completion was canceled.
2335 pub fn cancel_last_completion(
2336 &mut self,
2337 window: Option<AnyWindowHandle>,
2338 cx: &mut Context<Self>,
2339 ) -> bool {
2340 let mut canceled = self.pending_completions.pop().is_some();
2341
2342 for pending_tool_use in self.tool_use.cancel_pending() {
2343 canceled = true;
2344 self.tool_finished(
2345 pending_tool_use.id.clone(),
2346 Some(pending_tool_use),
2347 true,
2348 window,
2349 cx,
2350 );
2351 }
2352
2353 if canceled {
2354 cx.emit(ThreadEvent::CompletionCanceled);
2355
2356 // When canceled, we always want to insert the checkpoint.
2357 // (We skip over finalize_pending_checkpoint, because it
2358 // would conclude we didn't have anything to insert here.)
2359 if let Some(checkpoint) = self.pending_checkpoint.take() {
2360 self.insert_checkpoint(checkpoint, cx);
2361 }
2362 } else {
2363 self.finalize_pending_checkpoint(cx);
2364 }
2365
2366 canceled
2367 }
2368
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        // Pure notification: the thread itself holds no editing state.
        cx.emit(ThreadEvent::CancelEditing);
    }
2376
    /// Returns the thread-level feedback rating, if the user has left one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2380
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2384
    /// Records `feedback` for a single message and reports it to telemetry,
    /// along with a serialized copy of the thread and a project snapshot.
    ///
    /// Re-submitting the same rating for the same message is a no-op. The
    /// telemetry upload runs on a background task; the returned task resolves
    /// once the events have been flushed.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failures degrade to a null payload rather than
            // dropping the rating event entirely.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2441
    /// Records thread-level `feedback`.
    ///
    /// When the thread contains an assistant message, the rating is attached
    /// to the most recent one via [`Self::report_message_feedback`]; otherwise
    /// the rating is stored on the thread itself and reported to telemetry
    /// without a message id.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet: rate the thread as a whole.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Degrade serialization failures to a null payload.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2487
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree/git data is gathered asynchronously; the paths of dirty
    /// buffers are collected on the main thread just before assembly.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Kick off one snapshot task per visible worktree.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            // Best effort: if the app context is gone, record no unsaved buffers.
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2525
    /// Captures a [`WorktreeSnapshot`] — the worktree's path plus, when a
    /// local repository is found for it, git remote/HEAD/branch/diff details.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app is gone, return an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository that contains this worktree's root, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Remote/HEAD/diff details are only available for
                            // local repositories.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested fallible layers (entity update, job handle,
            // awaited job result); any failure yields no git state.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2603
2604 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2605 let mut markdown = Vec::new();
2606
2607 let summary = self.summary().or_default();
2608 writeln!(markdown, "# {summary}\n")?;
2609
2610 for message in self.messages() {
2611 writeln!(
2612 markdown,
2613 "## {role}\n",
2614 role = match message.role {
2615 Role::User => "User",
2616 Role::Assistant => "Agent",
2617 Role::System => "System",
2618 }
2619 )?;
2620
2621 if !message.loaded_context.text.is_empty() {
2622 writeln!(markdown, "{}", message.loaded_context.text)?;
2623 }
2624
2625 if !message.loaded_context.images.is_empty() {
2626 writeln!(
2627 markdown,
2628 "\n{} images attached as context.\n",
2629 message.loaded_context.images.len()
2630 )?;
2631 }
2632
2633 for segment in &message.segments {
2634 match segment {
2635 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2636 MessageSegment::Thinking { text, .. } => {
2637 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2638 }
2639 MessageSegment::RedactedThinking(_) => {}
2640 }
2641 }
2642
2643 for tool_use in self.tool_uses_for_message(message.id, cx) {
2644 writeln!(
2645 markdown,
2646 "**Use Tool: {} ({})**",
2647 tool_use.name, tool_use.id
2648 )?;
2649 writeln!(markdown, "```json")?;
2650 writeln!(
2651 markdown,
2652 "{}",
2653 serde_json::to_string_pretty(&tool_use.input)?
2654 )?;
2655 writeln!(markdown, "```")?;
2656 }
2657
2658 for tool_result in self.tool_results_for_message(message.id) {
2659 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2660 if tool_result.is_error {
2661 write!(markdown, " (Error)")?;
2662 }
2663
2664 writeln!(markdown, "**\n")?;
2665 match &tool_result.content {
2666 LanguageModelToolResultContent::Text(text) => {
2667 writeln!(markdown, "{text}")?;
2668 }
2669 LanguageModelToolResultContent::Image(image) => {
2670 writeln!(markdown, "", image.source)?;
2671 }
2672 }
2673
2674 if let Some(output) = tool_result.output.as_ref() {
2675 writeln!(
2676 markdown,
2677 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2678 serde_json::to_string_pretty(output)?
2679 )?;
2680 }
2681 }
2682 }
2683
2684 Ok(String::from_utf8_lossy(&markdown).to_string())
2685 }
2686
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept,
    /// delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2697
    /// Marks every edit tracked by the action log as kept.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2702
    /// Rejects (reverts) the agent's edits within `buffer_ranges` of `buffer`,
    /// delegating to the action log. The returned task resolves when the
    /// rejection has been applied.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2713
    /// Returns a handle to the action log that tracks the agent's edits.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2717
    /// Returns a handle to the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2721
    /// Opportunistically uploads a serialized copy of the thread for users in
    /// the thread-auto-capture feature flag, at most once every 10 seconds.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
            return;
        }

        // Rate-limit captures to one per 10 seconds.
        let now = Instant::now();
        if let Some(last) = self.last_auto_capture_at {
            if now.duration_since(last).as_secs() < 10 {
                return;
            }
        }

        self.last_auto_capture_at = Some(now);

        let thread_id = self.id().clone();
        let github_login = self
            .project
            .read(cx)
            .user_store()
            .read(cx)
            .current_user()
            .map(|user| user.github_login.clone());
        let client = self.project.read(cx).client();
        let serialize_task = self.serialize(cx);

        // Best-effort telemetry: serialization or upload failures are
        // silently ignored.
        cx.background_executor()
            .spawn(async move {
                if let Ok(serialized_thread) = serialize_task.await {
                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
                        telemetry::event!(
                            "Agent Thread Auto-Captured",
                            thread_id = thread_id.to_string(),
                            thread_data = thread_data,
                            auto_capture_reason = "tracked_user",
                            github_login = github_login
                        );

                        client.telemetry().flush_events().await;
                    }
                }
            })
            .detach();
    }
2765
    /// Returns the token usage accumulated over the whole thread so far.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2769
2770 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2771 let Some(model) = self.configured_model.as_ref() else {
2772 return TotalTokenUsage::default();
2773 };
2774
2775 let max = model.model.max_token_count();
2776
2777 let index = self
2778 .messages
2779 .iter()
2780 .position(|msg| msg.id == message_id)
2781 .unwrap_or(0);
2782
2783 if index == 0 {
2784 return TotalTokenUsage { total: 0, max };
2785 }
2786
2787 let token_usage = &self
2788 .request_token_usage
2789 .get(index - 1)
2790 .cloned()
2791 .unwrap_or_default();
2792
2793 TotalTokenUsage {
2794 total: token_usage.total_tokens(),
2795 max,
2796 }
2797 }
2798
2799 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2800 let model = self.configured_model.as_ref()?;
2801
2802 let max = model.model.max_token_count();
2803
2804 if let Some(exceeded_error) = &self.exceeded_window_error {
2805 if model.model.id() == exceeded_error.model_id {
2806 return Some(TotalTokenUsage {
2807 total: exceeded_error.token_count,
2808 max,
2809 });
2810 }
2811 }
2812
2813 let total = self
2814 .token_usage_at_last_message()
2815 .unwrap_or_default()
2816 .total_tokens();
2817
2818 Some(TotalTokenUsage { total, max })
2819 }
2820
2821 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2822 self.request_token_usage
2823 .get(self.messages.len().saturating_sub(1))
2824 .or_else(|| self.request_token_usage.last())
2825 .cloned()
2826 }
2827
2828 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2829 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2830 self.request_token_usage
2831 .resize(self.messages.len(), placeholder);
2832
2833 if let Some(last) = self.request_token_usage.last_mut() {
2834 *last = token_usage;
2835 }
2836 }
2837
    /// Propagates the latest model-request usage (`amount` used against
    /// `limit`) to the user store.
    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
        self.project.update(cx, |project, cx| {
            project.user_store().update(cx, |user_store, cx| {
                user_store.update_model_request_usage(
                    ModelRequestUsage(RequestUsage {
                        // Upstream tracks usage as a signed count.
                        amount: amount as i32,
                        limit,
                    }),
                    cx,
                )
            })
        });
    }
2851
2852 pub fn deny_tool_use(
2853 &mut self,
2854 tool_use_id: LanguageModelToolUseId,
2855 tool_name: Arc<str>,
2856 window: Option<AnyWindowHandle>,
2857 cx: &mut Context<Self>,
2858 ) {
2859 let err = Err(anyhow::anyhow!(
2860 "Permission to run tool action denied by user"
2861 ));
2862
2863 self.tool_use.insert_tool_output(
2864 tool_use_id.clone(),
2865 tool_name,
2866 err,
2867 self.configured_model.as_ref(),
2868 );
2869 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2870 }
2871}
2872
/// User-facing errors surfaced while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's plan requires payment before further requests are accepted.
    #[error("Payment required")]
    PaymentRequired,
    /// The user hit the model-request quota for their plan.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a header and detail text for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2885
/// Notifications emitted by a [`Thread`] as it runs.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    /// Progress was made on the streaming completion.
    StreamedCompletion,
    /// A text chunk arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a request to use a tool.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model referenced a tool that doesn't exist or isn't enabled.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stream ended, successfully or with an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// The listed tool uses are ready to be run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool run concluded.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// The thread's checkpoint state changed.
    CheckpointChanged,
    /// A tool needs explicit user confirmation before it may run.
    ToolConfirmationNeeded,
    /// The tool-use limit was reached.
    ToolUseLimitReached,
    /// Any in-progress editing should be canceled.
    CancelEditing,
    /// A pending completion was canceled.
    CompletionCanceled,
    /// The thread's agent profile changed.
    ProfileChanged,
}
2930
// Lets observers subscribe to the `ThreadEvent`s this entity emits via `cx.emit`.
impl EventEmitter<ThreadEvent> for Thread {}
2932
/// An in-flight completion request, kept alive by its background task.
struct PendingCompletion {
    /// Identifier distinguishing this completion from others on the thread.
    id: usize,
    queue_state: QueueState,
    /// Owns the streaming work; popping this entry (as `cancel_last_completion`
    /// does) drops the task and abandons the request.
    _task: Task<()>,
}
2938
2939/// Resolves tool name conflicts by ensuring all tool names are unique.
2940///
2941/// When multiple tools have the same name, this function applies the following rules:
2942/// 1. Native tools always keep their original name
2943/// 2. Context server tools get prefixed with their server ID and an underscore
2944/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
2945/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
2946///
2947/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
2948fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
2949 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
2950 let mut tool_name = tool.name();
2951 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
2952 tool_name
2953 }
2954
2955 const MAX_TOOL_NAME_LENGTH: usize = 64;
2956
2957 let mut duplicated_tool_names = HashSet::default();
2958 let mut seen_tool_names = HashSet::default();
2959 for tool in tools {
2960 let tool_name = resolve_tool_name(tool);
2961 if seen_tool_names.contains(&tool_name) {
2962 debug_assert!(
2963 tool.source() != assistant_tool::ToolSource::Native,
2964 "There are two built-in tools with the same name: {}",
2965 tool_name
2966 );
2967 duplicated_tool_names.insert(tool_name);
2968 } else {
2969 seen_tool_names.insert(tool_name);
2970 }
2971 }
2972
2973 if duplicated_tool_names.is_empty() {
2974 return tools
2975 .into_iter()
2976 .map(|tool| (resolve_tool_name(tool), tool.clone()))
2977 .collect();
2978 }
2979
2980 tools
2981 .into_iter()
2982 .filter_map(|tool| {
2983 let mut tool_name = resolve_tool_name(tool);
2984 if !duplicated_tool_names.contains(&tool_name) {
2985 return Some((tool_name, tool.clone()));
2986 }
2987 match tool.source() {
2988 assistant_tool::ToolSource::Native => {
2989 // Built-in tools always keep their original name
2990 Some((tool_name, tool.clone()))
2991 }
2992 assistant_tool::ToolSource::ContextServer { id } => {
2993 // Context server tools are prefixed with the context server ID, and truncated if necessary
2994 tool_name.insert(0, '_');
2995 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
2996 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
2997 let mut id = id.to_string();
2998 id.truncate(len);
2999 tool_name.insert_str(0, &id);
3000 } else {
3001 tool_name.insert_str(0, &id);
3002 }
3003
3004 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
3005
3006 if seen_tool_names.contains(&tool_name) {
3007 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
3008 None
3009 } else {
3010 Some((tool_name, tool.clone()))
3011 }
3012 }
3013 }
3014 })
3015 .collect()
3016}
3017
3018#[cfg(test)]
3019mod tests {
3020 use super::*;
3021 use crate::{
3022 context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3023 };
3024 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3025 use assistant_tool::ToolRegistry;
3026 use gpui::TestAppContext;
3027 use icons::IconName;
3028 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3029 use project::{FakeFs, Project};
3030 use prompt_store::PromptBuilder;
3031 use serde_json::json;
3032 use settings::{Settings, SettingsStore};
3033 use std::sync::Arc;
3034 use theme::ThemeSettings;
3035 use util::path;
3036 use workspace::Workspace;
3037
    // Verifies that file context attached to a user message is rendered into
    // the message's loaded context and included in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Two messages: system prompt plus the user message (with its context).
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3114
    // Verifies that each new user message only carries context that hasn't
    // already been attached to an earlier message in the thread, and that
    // `new_context_for_thread` respects an "up to this message" boundary.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Asking for new context "as of" message 2 should re-offer everything
        // attached at or after that message, plus the newly added file4.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3270
3271 #[gpui::test]
3272 async fn test_message_without_files(cx: &mut TestAppContext) {
3273 init_test_settings(cx);
3274
3275 let project = create_test_project(
3276 cx,
3277 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3278 )
3279 .await;
3280
3281 let (_, _thread_store, thread, _context_store, model) =
3282 setup_test_environment(cx, project.clone()).await;
3283
3284 // Insert user message without any context (empty context vector)
3285 let message_id = thread.update(cx, |thread, cx| {
3286 thread.insert_user_message(
3287 "What is the best way to learn Rust?",
3288 ContextLoadResult::default(),
3289 None,
3290 Vec::new(),
3291 cx,
3292 )
3293 });
3294
3295 // Check content and context in message object
3296 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3297
3298 // Context should be empty when no files are included
3299 assert_eq!(message.role, Role::User);
3300 assert_eq!(message.segments.len(), 1);
3301 assert_eq!(
3302 message.segments[0],
3303 MessageSegment::Text("What is the best way to learn Rust?".to_string())
3304 );
3305 assert_eq!(message.loaded_context.text, "");
3306
3307 // Check message in request
3308 let request = thread.update(cx, |thread, cx| {
3309 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3310 });
3311
3312 assert_eq!(request.messages.len(), 2);
3313 assert_eq!(
3314 request.messages[1].string_contents(),
3315 "What is the best way to learn Rust?"
3316 );
3317
3318 // Add second message, also without context
3319 let message2_id = thread.update(cx, |thread, cx| {
3320 thread.insert_user_message(
3321 "Are there any good books?",
3322 ContextLoadResult::default(),
3323 None,
3324 Vec::new(),
3325 cx,
3326 )
3327 });
3328
3329 let message2 =
3330 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3331 assert_eq!(message2.loaded_context.text, "");
3332
3333 // Check that both messages appear in the request
3334 let request = thread.update(cx, |thread, cx| {
3335 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3336 });
3337
3338 assert_eq!(request.messages.len(), 3);
3339 assert_eq!(
3340 request.messages[1].string_contents(),
3341 "What is the best way to learn Rust?"
3342 );
3343 assert_eq!(
3344 request.messages[2].string_contents(),
3345 "Are there any good books?"
3346 );
3347 }
3348
3349 #[gpui::test]
3350 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3351 init_test_settings(cx);
3352
3353 let project = create_test_project(
3354 cx,
3355 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3356 )
3357 .await;
3358
3359 let (_workspace, thread_store, thread, _context_store, _model) =
3360 setup_test_environment(cx, project.clone()).await;
3361
3362 // Check that we are starting with the default profile
3363 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3364 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3365 assert_eq!(
3366 profile,
3367 AgentProfile::new(AgentProfileId::default(), tool_set)
3368 );
3369 }
3370
3371 #[gpui::test]
3372 async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
3373 init_test_settings(cx);
3374
3375 let project = create_test_project(
3376 cx,
3377 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3378 )
3379 .await;
3380
3381 let (_workspace, thread_store, thread, _context_store, _model) =
3382 setup_test_environment(cx, project.clone()).await;
3383
3384 // Profile gets serialized with default values
3385 let serialized = thread
3386 .update(cx, |thread, cx| thread.serialize(cx))
3387 .await
3388 .unwrap();
3389
3390 assert_eq!(serialized.profile, Some(AgentProfileId::default()));
3391
3392 let deserialized = cx.update(|cx| {
3393 thread.update(cx, |thread, cx| {
3394 Thread::deserialize(
3395 thread.id.clone(),
3396 serialized,
3397 thread.project.clone(),
3398 thread.tools.clone(),
3399 thread.prompt_builder.clone(),
3400 thread.project_context.clone(),
3401 None,
3402 cx,
3403 )
3404 })
3405 });
3406 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3407
3408 assert_eq!(
3409 deserialized.profile,
3410 AgentProfile::new(AgentProfileId::default(), tool_set)
3411 );
3412 }
3413
3414 #[gpui::test]
3415 async fn test_temperature_setting(cx: &mut TestAppContext) {
3416 init_test_settings(cx);
3417
3418 let project = create_test_project(
3419 cx,
3420 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3421 )
3422 .await;
3423
3424 let (_workspace, _thread_store, thread, _context_store, model) =
3425 setup_test_environment(cx, project.clone()).await;
3426
3427 // Both model and provider
3428 cx.update(|cx| {
3429 AgentSettings::override_global(
3430 AgentSettings {
3431 model_parameters: vec![LanguageModelParameters {
3432 provider: Some(model.provider_id().0.to_string().into()),
3433 model: Some(model.id().0.clone()),
3434 temperature: Some(0.66),
3435 }],
3436 ..AgentSettings::get_global(cx).clone()
3437 },
3438 cx,
3439 );
3440 });
3441
3442 let request = thread.update(cx, |thread, cx| {
3443 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3444 });
3445 assert_eq!(request.temperature, Some(0.66));
3446
3447 // Only model
3448 cx.update(|cx| {
3449 AgentSettings::override_global(
3450 AgentSettings {
3451 model_parameters: vec![LanguageModelParameters {
3452 provider: None,
3453 model: Some(model.id().0.clone()),
3454 temperature: Some(0.66),
3455 }],
3456 ..AgentSettings::get_global(cx).clone()
3457 },
3458 cx,
3459 );
3460 });
3461
3462 let request = thread.update(cx, |thread, cx| {
3463 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3464 });
3465 assert_eq!(request.temperature, Some(0.66));
3466
3467 // Only provider
3468 cx.update(|cx| {
3469 AgentSettings::override_global(
3470 AgentSettings {
3471 model_parameters: vec![LanguageModelParameters {
3472 provider: Some(model.provider_id().0.to_string().into()),
3473 model: None,
3474 temperature: Some(0.66),
3475 }],
3476 ..AgentSettings::get_global(cx).clone()
3477 },
3478 cx,
3479 );
3480 });
3481
3482 let request = thread.update(cx, |thread, cx| {
3483 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3484 });
3485 assert_eq!(request.temperature, Some(0.66));
3486
3487 // Same model name, different provider
3488 cx.update(|cx| {
3489 AgentSettings::override_global(
3490 AgentSettings {
3491 model_parameters: vec![LanguageModelParameters {
3492 provider: Some("anthropic".into()),
3493 model: Some(model.id().0.clone()),
3494 temperature: Some(0.66),
3495 }],
3496 ..AgentSettings::get_global(cx).clone()
3497 },
3498 cx,
3499 );
3500 });
3501
3502 let request = thread.update(cx, |thread, cx| {
3503 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3504 });
3505 assert_eq!(request.temperature, None);
3506 }
3507
    /// Exercises the `ThreadSummary` lifecycle shown by this test:
    /// Pending -> Generating -> Ready, and when `set_summary` is honored
    /// (only once a summary is Ready).
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        // Complete the assistant turn; afterwards the thread has two messages.
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream; the
        // run_until_parked calls let the request reach the fake model and let
        // the thread consume the streamed events.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set (the two streamed chunks, concatenated)
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3592
3593 #[gpui::test]
3594 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3595 init_test_settings(cx);
3596
3597 let project = create_test_project(cx, json!({})).await;
3598
3599 let (_, _thread_store, thread, _context_store, model) =
3600 setup_test_environment(cx, project.clone()).await;
3601
3602 test_summarize_error(&model, &thread, cx);
3603
3604 // Now we should be able to set a summary
3605 thread.update(cx, |thread, cx| {
3606 thread.set_summary("Brief Intro", cx);
3607 });
3608
3609 thread.read_with(cx, |thread, _| {
3610 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3611 assert_eq!(thread.summary().or_default(), "Brief Intro");
3612 });
3613 }
3614
    /// After a summarization failure, further normal messages do not re-trigger
    /// summarization automatically, but calling `summarize` manually retries
    /// and can succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Put the thread's summary into the Error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // This time stream actual summary text before closing the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // The retry overwrites the earlier error state.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3665
    /// Verifies tool-name conflict resolution: duplicated names are prefixed
    /// with their context-server id, and resulting names are truncated to at
    /// most 64 characters (as demonstrated by the fixtures below).
    #[gpui::test]
    fn test_resolve_tool_name_conflicts() {
        use assistant_tool::{Tool, ToolSource};

        // No conflicts: every name passes through unchanged.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
            ],
            vec!["tool1", "tool2", "tool3"],
        );

        // The same name from two context servers: both get a `<server>_` prefix.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // A native tool keeps its bare name even when context-server tools
        // collide with it; only the context-server copies are prefixed.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // Test that tool with very long name is always truncated
        // (expected value is the input cut to 64 characters).
        assert_resolve_tool_name_conflicts(
            vec![TestTool::new(
                "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
                ToolSource::Native,
            )],
            vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
        );

        // Test deduplication of tools with very long names, in this case the mcp server name should be truncated
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
                TestTool::new(
                    "tool-with-very-very-very-long-name",
                    ToolSource::ContextServer {
                        id: "mcp-with-very-very-very-long-name".into(),
                    },
                ),
            ],
            vec![
                "tool-with-very-very-very-long-name",
                "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
            ],
        );

        // Runs `resolve_tool_name_conflicts` over `tools` and asserts the
        // resolved names equal `expected`, position by position.
        fn assert_resolve_tool_name_conflicts(
            tools: Vec<TestTool>,
            expected: Vec<impl Into<String>>,
        ) {
            let tools: Vec<Arc<dyn Tool>> = tools
                .into_iter()
                .map(|t| Arc::new(t) as Arc<dyn Tool>)
                .collect();
            let tools = resolve_tool_name_conflicts(&tools);
            assert_eq!(tools.len(), expected.len());
            for (i, expected_name) in expected.into_iter().enumerate() {
                let expected_name = expected_name.into();
                let actual_name = &tools[i].0;
                assert_eq!(
                    actual_name, &expected_name,
                    "Expected '{}' got '{}' at index {}",
                    expected_name, actual_name, i
                );
            }
        }

        // Minimal fixture tool: only `name` and `source` matter for conflict
        // resolution; the remaining trait methods are stubs.
        struct TestTool {
            name: String,
            source: ToolSource,
        }

        impl TestTool {
            fn new(name: impl Into<String>, source: ToolSource) -> Self {
                Self {
                    name: name.into(),
                    source,
                }
            }
        }

        impl Tool for TestTool {
            fn name(&self) -> String {
                self.name.clone()
            }

            fn icon(&self) -> IconName {
                IconName::Ai
            }

            fn may_perform_edits(&self) -> bool {
                false
            }

            fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
                true
            }

            fn source(&self) -> ToolSource {
                self.source.clone()
            }

            fn description(&self) -> String {
                "Test tool".to_string()
            }

            fn ui_text(&self, _input: &serde_json::Value) -> String {
                "Test tool".to_string()
            }

            // Never meant to run; always yields an error result.
            fn run(
                self: Arc<Self>,
                _input: serde_json::Value,
                _request: Arc<LanguageModelRequest>,
                _project: Entity<Project>,
                _action_log: Entity<ActionLog>,
                _model: Arc<dyn LanguageModel>,
                _window: Option<AnyWindowHandle>,
                _cx: &mut App,
            ) -> assistant_tool::ToolResult {
                assistant_tool::ToolResult {
                    output: Task::ready(Err(anyhow::anyhow!("No content"))),
                    card: None,
                }
            }
        }
    }
3807
    /// Shared scenario used by the summary-error tests: sends a message, lets
    /// the assistant respond, then ends the summary completion stream without
    /// streaming any text, leaving the thread in `ThreadSummary::Error`.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // The completed assistant turn kicks off summary generation.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending
        // (closed with no streamed content, which this test shows is recorded
        // as a failure).
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3842
    /// Drives the fake model through one complete assistant turn: waits for the
    /// pending completion request to arrive, streams a canned response, closes
    /// the stream, and lets the thread consume the events.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        // Let the in-flight completion request reach the fake model first.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        // Let the thread process the streamed response.
        cx.run_until_parked();
    }
3849
    /// Installs the global state the thread tests depend on. The settings
    /// store must be set as a global before the subsequent `init`/`register`
    /// calls, since those read and write global settings.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3865
3866 // Helper to create a test project with test files
3867 async fn create_test_project(
3868 cx: &mut TestAppContext,
3869 files: serde_json::Value,
3870 ) -> Entity<Project> {
3871 let fs = FakeFs::new(cx.executor());
3872 fs.insert_tree(path!("/test"), files).await;
3873 Project::test(fs, [path!("/test").as_ref()], cx).await
3874 }
3875
3876 async fn setup_test_environment(
3877 cx: &mut TestAppContext,
3878 project: Entity<Project>,
3879 ) -> (
3880 Entity<Workspace>,
3881 Entity<ThreadStore>,
3882 Entity<Thread>,
3883 Entity<ContextStore>,
3884 Arc<dyn LanguageModel>,
3885 ) {
3886 let (workspace, cx) =
3887 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3888
3889 let thread_store = cx
3890 .update(|_, cx| {
3891 ThreadStore::load(
3892 project.clone(),
3893 cx.new(|_| ToolWorkingSet::default()),
3894 None,
3895 Arc::new(PromptBuilder::new(None).unwrap()),
3896 cx,
3897 )
3898 })
3899 .await
3900 .unwrap();
3901
3902 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3903 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3904
3905 let provider = Arc::new(FakeLanguageModelProvider);
3906 let model = provider.test_model();
3907 let model: Arc<dyn LanguageModel> = Arc::new(model);
3908
3909 cx.update(|_, cx| {
3910 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3911 registry.set_default_model(
3912 Some(ConfiguredModel {
3913 provider: provider.clone(),
3914 model: model.clone(),
3915 }),
3916 cx,
3917 );
3918 registry.set_thread_summary_model(
3919 Some(ConfiguredModel {
3920 provider,
3921 model: model.clone(),
3922 }),
3923 cx,
3924 );
3925 })
3926 });
3927
3928 (workspace, thread_store, thread, context_store, model)
3929 }
3930
3931 async fn add_file_to_context(
3932 project: &Entity<Project>,
3933 context_store: &Entity<ContextStore>,
3934 path: &str,
3935 cx: &mut TestAppContext,
3936 ) -> Result<Entity<language::Buffer>> {
3937 let buffer_path = project
3938 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3939 .unwrap();
3940
3941 let buffer = project
3942 .update(cx, |project, cx| {
3943 project.open_buffer(buffer_path.clone(), cx)
3944 })
3945 .await
3946 .unwrap();
3947
3948 context_store.update(cx, |context_store, cx| {
3949 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3950 });
3951
3952 Ok(buffer)
3953 }
3954}