1use crate::{
2 agent_profile::AgentProfile,
3 context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
4 thread_store::{
5 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
6 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
7 ThreadStore,
8 },
9 tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
10};
11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
12use anyhow::{Result, anyhow};
13use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
14use chrono::{DateTime, Utc};
15use client::{ModelRequestUsage, RequestUsage};
16use collections::{HashMap, HashSet};
17use feature_flags::{self, FeatureFlagAppExt};
18use futures::{FutureExt, StreamExt as _, future::Shared};
19use git::repository::DiffType;
20use gpui::{
21 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
22 WeakEntity, Window,
23};
24use language_model::{
25 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
26 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
27 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
28 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
29 ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
30 TokenUsage,
31};
32use postage::stream::Stream as _;
33use project::{
34 Project,
35 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
36};
37use prompt_store::{ModelContext, PromptBuilder};
38use proto::Plan;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use settings::Settings;
42use std::{io::Write, ops::Range, sync::Arc, time::Instant};
43use thiserror::Error;
44use util::{ResultExt as _, post_inc};
45use uuid::Uuid;
46use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
47
48#[derive(
49 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
50)]
51pub struct ThreadId(Arc<str>);
52
53impl ThreadId {
54 pub fn new() -> Self {
55 Self(Uuid::new_v4().to_string().into())
56 }
57}
58
59impl std::fmt::Display for ThreadId {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 write!(f, "{}", self.0)
62 }
63}
64
65impl From<&str> for ThreadId {
66 fn from(value: &str) -> Self {
67 Self(value.into())
68 }
69}
70
71/// The ID of the user prompt that initiated a request.
72///
73/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
74#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
75pub struct PromptId(Arc<str>);
76
77impl PromptId {
78 pub fn new() -> Self {
79 Self(Uuid::new_v4().to_string().into())
80 }
81}
82
83impl std::fmt::Display for PromptId {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 write!(f, "{}", self.0)
86 }
87}
88
/// Monotonically increasing, thread-local identifier of a [`Message`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }

    /// The raw numeric value of this id.
    pub fn as_usize(&self) -> usize {
        self.0
    }
}
101
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message text.
    pub range: Range<usize>,
    pub icon_path: SharedString,
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
111
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered body parts: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context attached when the message was sent (text/images/handles).
    pub loaded_context: LoadedContext,
    /// Creases used to re-create context pills when editing this message.
    pub creases: Vec<MessageCrease>,
    // Hidden messages are still part of the conversation sent to the model;
    // presumably they are excluded from display (see `is_turn_end`).
    pub is_hidden: bool,
}
122
123impl Message {
124 /// Returns whether the message contains any meaningful text that should be displayed
125 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
126 pub fn should_display_content(&self) -> bool {
127 self.segments.iter().all(|segment| segment.should_display())
128 }
129
130 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
131 if let Some(MessageSegment::Thinking {
132 text: segment,
133 signature: current_signature,
134 }) = self.segments.last_mut()
135 {
136 if let Some(signature) = signature {
137 *current_signature = Some(signature);
138 }
139 segment.push_str(text);
140 } else {
141 self.segments.push(MessageSegment::Thinking {
142 text: text.to_string(),
143 signature,
144 });
145 }
146 }
147
148 pub fn push_redacted_thinking(&mut self, data: String) {
149 self.segments.push(MessageSegment::RedactedThinking(data));
150 }
151
152 pub fn push_text(&mut self, text: &str) {
153 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
154 segment.push_str(text);
155 } else {
156 self.segments.push(MessageSegment::Text(text.to_string()));
157 }
158 }
159
160 pub fn to_string(&self) -> String {
161 let mut result = String::new();
162
163 if !self.loaded_context.text.is_empty() {
164 result.push_str(&self.loaded_context.text);
165 }
166
167 for segment in &self.segments {
168 match segment {
169 MessageSegment::Text(text) => result.push_str(text),
170 MessageSegment::Thinking { text, .. } => {
171 result.push_str("<think>\n");
172 result.push_str(text);
173 result.push_str("\n</think>");
174 }
175 MessageSegment::RedactedThinking(_) => {}
176 }
177 }
178
179 result
180 }
181}
182
/// One piece of a [`Message`] body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain response text.
    Text(String),
    /// Extended "thinking" text, optionally accompanied by a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-encrypted thinking payload; never displayable.
    RedactedThinking(String),
}

impl MessageSegment {
    /// Returns whether this segment has visible content worth displaying.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking never is. (Previously the emptiness check was
    /// inverted, reporting empty segments as displayable and non-empty ones
    /// as not.)
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
202
/// Point-in-time capture of the project's worktrees.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Presumably paths of buffers with unsaved changes at capture time —
    // populated elsewhere; confirm against the snapshot producer.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including git state when available.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when no git state could be captured for this worktree.
    pub git_state: Option<GitState>,
}

/// Git repository state captured for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Diff text, if one was computed.
    pub diff: Option<String>,
}
223
/// Associates a message with a git checkpoint so the project can be rolled
/// back to the state it had when that message was sent.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Thumbs-up / thumbs-down feedback on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
235
236pub enum LastRestoreCheckpoint {
237 Pending {
238 message_id: MessageId,
239 },
240 Error {
241 message_id: MessageId,
242 error: String,
243 },
244}
245
246impl LastRestoreCheckpoint {
247 pub fn message_id(&self) -> MessageId {
248 match self {
249 LastRestoreCheckpoint::Pending { message_id } => *message_id,
250 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
251 }
252 }
253}
254
255#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
256pub enum DetailedSummaryState {
257 #[default]
258 NotGenerated,
259 Generating {
260 message_id: MessageId,
261 },
262 Generated {
263 text: SharedString,
264 message_id: MessageId,
265 },
266}
267
268impl DetailedSummaryState {
269 fn text(&self) -> Option<SharedString> {
270 if let Self::Generated { text, .. } = self {
271 Some(text.clone())
272 } else {
273 None
274 }
275 }
276}
277
/// Running token total for a thread, relative to the model's context window
/// (`max == 0` means the window size is unknown).
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: u64,
    pub max: u64,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, nearing the limit, or exceeded.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden via an
        // environment variable to ease manual testing.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .map_or(0.8, |value| value.parse().unwrap());
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            return TokenUsageRatio::Normal;
        }
        if self.total >= self.max {
            return TokenUsageRatio::Exceeded;
        }
        let used_fraction = self.total as f32 / self.max as f32;
        if used_fraction >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` added to the total.
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// Coarse classification of token usage for UI warnings.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
322
/// Status of a pending completion request with respect to the provider queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// Request sent; no queue status received yet.
    Sending,
    /// Request is waiting at `position` in the queue.
    Queued { position: usize },
    /// The completion has started streaming.
    Started,
}
329
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // Short summary shown in lists; the detailed summary is generated lazily
    // and broadcast over the watch channel below.
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Messages are kept sorted by id (relied upon by `message`'s binary search).
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints taken per user message, enabling rollback of edits.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint staged by the current turn, finalized when generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Project snapshot captured at thread creation; shared across awaiters.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Token accounting: per-request history plus the running total.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    // User feedback, thread-wide and per message.
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    last_received_chunk_at: Option<Instant>,
    // Optional hook observing each request and its raw event stream.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
370
371#[derive(Clone, Debug, PartialEq, Eq)]
372pub enum ThreadSummary {
373 Pending,
374 Generating,
375 Ready(SharedString),
376 Error,
377}
378
379impl ThreadSummary {
380 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
381
382 pub fn or_default(&self) -> SharedString {
383 self.unwrap_or(Self::DEFAULT)
384 }
385
386 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
387 self.ready().unwrap_or_else(|| message.into())
388 }
389
390 pub fn ready(&self) -> Option<SharedString> {
391 match self {
392 ThreadSummary::Ready(summary) => Some(summary.clone()),
393 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
394 }
395 }
396}
397
/// Details of the most recent request that exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
405
406impl Thread {
    /// Creates an empty thread, pulling defaults (completion mode, profile,
    /// model) from the global settings and model registry.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the initial project snapshot asynchronously; the
                // `Shared` future lets multiple awaiters reuse the result.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
461
    /// Rebuilds a [`Thread`] from its serialized form.
    ///
    /// Context handles and images attached to past messages are not restored;
    /// only their text and creases survive serialization.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest deserialized message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model recorded in the serialized thread; fall back to
        // the registry default when it is missing or no longer available.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives serialization.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            icon_path: crease.icon_path,
                            label: crease.label,
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
583
    /// Registers a hook invoked with each completion request and the raw
    /// events (stringified on error) received for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// This thread's unique identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// The agent profile (enabled tool set) currently in effect.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }
599
600 pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
601 if &id != self.profile.id() {
602 self.profile = AgentProfile::new(id, self.tools.clone());
603 cx.emit(ThreadEvent::ProfileChanged);
604 }
605 }
606
    /// Returns whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the modification timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Shared handle to the project context used for the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
626
627 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
628 if self.configured_model.is_none() {
629 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
630 }
631 self.configured_model.clone()
632 }
633
    /// The model currently configured for this thread, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the model used for completions.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Current state of the thread's one-line summary.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
646
647 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
648 let current_summary = match &self.summary {
649 ThreadSummary::Pending | ThreadSummary::Generating => return,
650 ThreadSummary::Ready(summary) => summary,
651 ThreadSummary::Error => &ThreadSummary::DEFAULT,
652 };
653
654 let mut new_summary = new_summary.into();
655
656 if new_summary.is_empty() {
657 new_summary = ThreadSummary::DEFAULT;
658 }
659
660 if current_summary != &new_summary {
661 self.summary = ThreadSummary::Ready(new_summary);
662 cx.emit(ThreadEvent::SummaryChanged);
663 }
664 }
665
    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Overrides the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
673
674 pub fn message(&self, id: MessageId) -> Option<&Message> {
675 let index = self
676 .messages
677 .binary_search_by(|message| message.id.cmp(&id))
678 .ok()?;
679
680 self.messages.get(index)
681 }
682
    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Returns whether a completion is in flight or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    ///
    /// NOTE(review): the code only checks `last_received_chunk_at` (None until
    /// the first chunk); it does not consult `is_generating()` directly —
    /// presumably the field is cleared when generation ends. Confirm.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before the stream counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue status of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
713
714 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
715 self.tool_use
716 .pending_tool_uses()
717 .into_iter()
718 .find(|tool_use| &tool_use.id == id)
719 }
720
    /// Pending tool uses that require user confirmation before they may run.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Returns whether any tool use is still pending (including errored ones).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for the given message, if one was taken.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
735
    /// Restores the project to the git checkpoint taken at a message and, on
    /// success, truncates the thread back to that message.
    ///
    /// Progress and failure are surfaced via `last_restore_checkpoint` and
    /// [`ThreadEvent::CheckpointChanged`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in flight so observers can show progress.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the error so it can be displayed next to the message.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Roll the conversation back alongside the project state.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
771
772 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
773 let pending_checkpoint = if self.is_generating() {
774 return;
775 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
776 checkpoint
777 } else {
778 return;
779 };
780
781 self.finalize_checkpoint(pending_checkpoint, cx);
782 }
783
    /// Compares the pending checkpoint against the repository's current state
    /// and keeps it only when the two differ (i.e. the turn changed files).
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // Treat a failed comparison as "not equal" so the checkpoint
                // is kept rather than silently dropped.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If the final checkpoint can't be taken, err on the side of
            // keeping the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
818
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The most recent checkpoint restore, which may be pending or errored.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
829
830 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
831 let Some(message_ix) = self
832 .messages
833 .iter()
834 .rposition(|message| message.id == message_id)
835 else {
836 return;
837 };
838 for deleted_message in self.messages.drain(message_ix..) {
839 self.checkpoints_by_message.remove(&deleted_message.id);
840 }
841 cx.notify();
842 }
843
    /// Iterates over the contexts attached to the message with the given id
    /// (empty when the message does not exist or carries no contexts).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
851
852 pub fn is_turn_end(&self, ix: usize) -> bool {
853 if self.messages.is_empty() {
854 return false;
855 }
856
857 if !self.is_generating() && ix == self.messages.len() - 1 {
858 return true;
859 }
860
861 let Some(message) = self.messages.get(ix) else {
862 return false;
863 };
864
865 if message.role != Role::Assistant {
866 return false;
867 }
868
869 self.messages
870 .get(ix + 1)
871 .and_then(|message| {
872 self.message(message.id)
873 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
874 })
875 .unwrap_or(false)
876 }
877
    /// Returns whether the tool-use limit was reached for this thread.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            // Errored tool uses no longer count as runnable.
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }
900
    /// Tool uses issued by the message with the given id.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced in response to the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The recorded result of a specific tool use, if any.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The text output of a tool use; image results are not yet surfaced.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// The UI card associated with a tool use, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
929
    /// Return tools that are both enabled and supported by the model
    ///
    /// Tool names are de-conflicted first; tools whose input schema cannot be
    /// expressed in the model's tool input format are skipped. Returns an
    /// empty list for models without tool support.
    pub fn available_tools(
        &self,
        cx: &App,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<LanguageModelRequestTool> {
        if model.supports_tools() {
            resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
                .into_iter()
                .filter_map(|(name, tool)| {
                    // Skip tools that cannot be supported
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name,
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect()
        } else {
            Vec::default()
        }
    }
953
    /// Appends a user message and, when a git checkpoint is supplied, stages
    /// it as the pending checkpoint for the new turn.
    ///
    /// Buffers referenced by the loaded context are marked as read in the
    /// action log.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
990
    /// Inserts a hidden "Continue where you left off" user message used to
    /// resume generation; clears any pending checkpoint since this does not
    /// start a new user-driven turn.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }

    /// Appends an assistant message composed of the given segments.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1019
    /// Appends a message with the next sequential id, bumps the modification
    /// time, and emits [`ThreadEvent::MessageAdded`]. Returns the new id.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1042
    /// Replaces the role, segments, and creases of an existing message;
    /// `loaded_context` and `checkpoint` are applied only when provided.
    ///
    /// Emits [`ThreadEvent::MessageEdited`] and returns `true` on success, or
    /// `false` when no message with `id` exists.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            // Editing a message re-anchors its checkpoint to the new state.
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1075
1076 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1077 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1078 return false;
1079 };
1080 self.messages.remove(index);
1081 self.touch_updated_at();
1082 cx.emit(ThreadEvent::MessageDeleted(id));
1083 true
1084 }
1085
1086 /// Returns the representation of this [`Thread`] in a textual form.
1087 ///
1088 /// This is the representation we use when attaching a thread as context to another thread.
1089 pub fn text(&self) -> String {
1090 let mut text = String::new();
1091
1092 for message in &self.messages {
1093 text.push_str(match message.role {
1094 language_model::Role::User => "User:",
1095 language_model::Role::Assistant => "Agent:",
1096 language_model::Role::System => "System:",
1097 });
1098 text.push('\n');
1099
1100 for segment in &message.segments {
1101 match segment {
1102 MessageSegment::Text(content) => text.push_str(content),
1103 MessageSegment::Thinking { text: content, .. } => {
1104 text.push_str(&format!("<think>{}</think>", content))
1105 }
1106 MessageSegment::RedactedThinking(_) => {}
1107 }
1108 }
1109 text.push('\n');
1110 }
1111
1112 text
1113 }
1114
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Returns a background task because the serialized form includes the
    /// initial project snapshot, which may still be resolving. The resulting
    /// [`SerializedThread`] captures messages (with their segments, tool
    /// uses/results, context, and creases), token usage, model selection,
    /// completion mode, and profile id.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the shared snapshot future before reading thread state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        // Each in-memory segment variant maps 1:1 onto its
                        // serialized counterpart.
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.icon_path.clone(),
                                label: crease.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1200
    /// Returns how many more turns this thread may take before
    /// [`Self::send_to_model`] becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1204
    /// Sets the turn budget; each call to [`Self::send_to_model`] consumes
    /// one turn and does nothing once the budget reaches zero.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1208
1209 pub fn send_to_model(
1210 &mut self,
1211 model: Arc<dyn LanguageModel>,
1212 intent: CompletionIntent,
1213 window: Option<AnyWindowHandle>,
1214 cx: &mut Context<Self>,
1215 ) {
1216 if self.remaining_turns == 0 {
1217 return;
1218 }
1219
1220 self.remaining_turns -= 1;
1221
1222 let request = self.to_completion_request(model.clone(), intent, cx);
1223
1224 self.stream_completion(request, model, window, cx);
1225 }
1226
1227 pub fn used_tools_since_last_user_message(&self) -> bool {
1228 for message in self.messages.iter().rev() {
1229 if self.tool_use.message_has_tool_results(message.id) {
1230 return true;
1231 } else if message.role == Role::User {
1232 return false;
1233 }
1234 }
1235
1236 false
1237 }
1238
    /// Builds the [`LanguageModelRequest`] for a completion over this
    /// thread's full message history.
    ///
    /// The request is assembled as: the generated system prompt (when the
    /// project context is ready), then one request message per thread
    /// message. Tool uses that already have results are attached to the
    /// assistant message, with their results carried in a synthetic `User`
    /// message pushed immediately after it; tool uses still pending are
    /// skipped. At most one message is marked `cache: true` — the last one
    /// whose tool uses are all resolved — to exploit prompt caching.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        // The system prompt mentions which tools exist, so gather them first.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Emit the system prompt, or surface an error if the project context
        // isn't ready / prompt generation fails. The request is still built
        // either way.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Tracks the index of the last request message eligible for caching.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) goes in first.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Resolved tool uses ride on the assistant message; their results
            // go in a synthetic user message that follows it. A message with
            // any still-pending tool use is not cacheable.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            // Models without max-mode support always run in normal mode.
            Some(CompletionMode::Normal.into())
        };

        request
    }
1396
1397 fn to_summarize_request(
1398 &self,
1399 model: &Arc<dyn LanguageModel>,
1400 intent: CompletionIntent,
1401 added_user_message: String,
1402 cx: &App,
1403 ) -> LanguageModelRequest {
1404 let mut request = LanguageModelRequest {
1405 thread_id: None,
1406 prompt_id: None,
1407 intent: Some(intent),
1408 mode: None,
1409 messages: vec![],
1410 tools: Vec::new(),
1411 tool_choice: None,
1412 stop: Vec::new(),
1413 temperature: AgentSettings::temperature_for_model(model, cx),
1414 };
1415
1416 for message in &self.messages {
1417 let mut request_message = LanguageModelRequestMessage {
1418 role: message.role,
1419 content: Vec::new(),
1420 cache: false,
1421 };
1422
1423 for segment in &message.segments {
1424 match segment {
1425 MessageSegment::Text(text) => request_message
1426 .content
1427 .push(MessageContent::Text(text.clone())),
1428 MessageSegment::Thinking { .. } => {}
1429 MessageSegment::RedactedThinking(_) => {}
1430 }
1431 }
1432
1433 if request_message.content.is_empty() {
1434 continue;
1435 }
1436
1437 request.messages.push(request_message);
1438 }
1439
1440 request.messages.push(LanguageModelRequestMessage {
1441 role: Role::User,
1442 content: vec![MessageContent::Text(added_user_message)],
1443 cache: false,
1444 });
1445
1446 request
1447 }
1448
    /// Streams a completion for `request` from `model`, applying each event
    /// to the thread as it arrives.
    ///
    /// Registers a [`PendingCompletion`] whose background task drives the
    /// stream: text/thinking chunks are appended to the trailing assistant
    /// message (inserting one when needed), tool uses are registered, and
    /// token usage / queue status are tracked. When the stream ends, the stop
    /// reason decides the follow-up: run pending tools on `ToolUse`, clear
    /// the agent location on `EndTurn`/`MaxTokens`, or unwind the refused
    /// turn on `Refusal`. Errors are mapped to `LanguageModelKnownError`
    /// where possible and surfaced via `ThreadEvent::ShowError`.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record the request/response pair when a callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before streaming so the delta can be reported
            // via telemetry afterwards.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // The assistant message this response streams into; created
                // lazily on the first event that needs one.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        // Translate stream errors into known error types so
                        // the outer handler can react appropriately.
                        let event = match event {
                            Ok(event) => event,
                            Err(error) => {
                                match error {
                                    LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
                                        anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
                                    }
                                    LanguageModelCompletionError::Overloaded => {
                                        anyhow::bail!(LanguageModelKnownError::Overloaded);
                                    }
                                    LanguageModelCompletionError::ApiInternalServerError =>{
                                        anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
                                    }
                                    LanguageModelCompletionError::PromptTooLarge { tokens } => {
                                        let tokens = tokens.unwrap_or_else(|| {
                                            // We didn't get an exact token count from the API, so fall back on our estimate.
                                            thread.total_token_usage()
                                                .map(|usage| usage.total)
                                                .unwrap_or(0)
                                                // We know the context window was exceeded in practice, so if our estimate was
                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
                                                .max(model.max_token_count().saturating_add(1))
                                        });

                                        anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
                                    }
                                    LanguageModelCompletionError::ApiReadResponseError(io_error) => {
                                        anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
                                    }
                                    LanguageModelCompletionError::UnknownResponseFormat(error) => {
                                        anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
                                    }
                                    LanguageModelCompletionError::HttpResponseError { status, ref body } => {
                                        if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
                                            anyhow::bail!(known_error);
                                        } else {
                                            return Err(error.into());
                                        }
                                    }
                                    LanguageModelCompletionError::DeserializeResponse(error) => {
                                        anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
                                    }
                                    LanguageModelCompletionError::BadInputJson {
                                        id,
                                        tool_name,
                                        raw_input: invalid_input_json,
                                        json_parse_error,
                                    } => {
                                        // Recoverable: record the bad tool call
                                        // and keep the stream alive.
                                        thread.receive_invalid_tool_json(
                                            id,
                                            tool_name,
                                            invalid_input_json,
                                            json_parse_error,
                                            window,
                                            cx,
                                        );
                                        return Ok(());
                                    }
                                    // These are all errors we can't automatically attempt to recover from (e.g. by retrying)
                                    err @ LanguageModelCompletionError::BadRequestFormat |
                                    err @ LanguageModelCompletionError::AuthenticationError |
                                    err @ LanguageModelCompletionError::PermissionError |
                                    err @ LanguageModelCompletionError::ApiEndpointNotFound |
                                    err @ LanguageModelCompletionError::SerializeRequest(_) |
                                    err @ LanguageModelCompletionError::BuildRequestBody(_) |
                                    err @ LanguageModelCompletionError::HttpSend(_) => {
                                        anyhow::bail!(err);
                                    }
                                    LanguageModelCompletionError::Other(error) => {
                                        return Err(error);
                                    }
                                }
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Usage updates are cumulative per request, so
                                // only add the delta since the previous update.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::RedactedThinking {
                                data
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_redacted_thinking(data);
                                    } else {
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::RedactedThinking(data)],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses need an assistant message to hang
                                // off; create one if none was started.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Incomplete input means the tool call is still
                                // streaming; remember it so we can notify below.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            thread.update_model_request_usage(amount as u32, limit, cx);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so UI work can interleave with the stream.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, model.clone(), cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                            StopReason::Refusal => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });

                                // Remove the turn that was refused.
                                //
                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                {
                                    let mut messages_to_remove = Vec::new();

                                    // Collect message ids back to (and including)
                                    // the user message that started the turn.
                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
                                        messages_to_remove.push(message.id);

                                        if message.role == Role::User {
                                            if ix == 0 {
                                                break;
                                            }

                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                if prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }
                                    }

                                    for message_id in messages_to_remove {
                                        thread.delete_message(message_id, cx);
                                    }
                                }

                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Language model refusal".into(),
                                    message: "Model refused to generate content for safety reasons.".into(),
                                }));
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Shared fallback: show the full error chain to the user.
                            fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                    LanguageModelKnownError::RateLimitExceeded { .. } => {
                                        // In the future we will report the error to the user, wait retry_after, and then retry.
                                        emit_generic_error(error, cx);
                                    }
                                    LanguageModelKnownError::Overloaded => {
                                        // In the future we will wait and then retry, up to N times.
                                        emit_generic_error(error, cx);
                                    }
                                    LanguageModelKnownError::ApiInternalServerError => {
                                        // In the future we will retry the request, but only once.
                                        emit_generic_error(error, cx);
                                    }
                                    LanguageModelKnownError::ReadResponseError(_) |
                                    LanguageModelKnownError::DeserializeResponse(_) |
                                    LanguageModelKnownError::UnknownResponseFormat(_) => {
                                        // In the future we will attempt to re-roll response, but only once
                                        emit_generic_error(error, cx);
                                    }
                                }
                            } else {
                                emit_generic_error(error, cx);
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }

                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report the token usage attributable to this completion.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1919
1920 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1921 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1922 println!("No thread summary model");
1923 return;
1924 };
1925
1926 if !model.provider.is_authenticated(cx) {
1927 return;
1928 }
1929
1930 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
1931
1932 let request = self.to_summarize_request(
1933 &model.model,
1934 CompletionIntent::ThreadSummarization,
1935 added_user_message.into(),
1936 cx,
1937 );
1938
1939 self.summary = ThreadSummary::Generating;
1940
1941 self.pending_summary = cx.spawn(async move |this, cx| {
1942 let result = async {
1943 let mut messages = model.model.stream_completion(request, &cx).await?;
1944
1945 let mut new_summary = String::new();
1946 while let Some(event) = messages.next().await {
1947 let Ok(event) = event else {
1948 continue;
1949 };
1950 let text = match event {
1951 LanguageModelCompletionEvent::Text(text) => text,
1952 LanguageModelCompletionEvent::StatusUpdate(
1953 CompletionRequestStatus::UsageUpdated { amount, limit },
1954 ) => {
1955 this.update(cx, |thread, cx| {
1956 thread.update_model_request_usage(amount as u32, limit, cx);
1957 })?;
1958 continue;
1959 }
1960 _ => continue,
1961 };
1962
1963 let mut lines = text.lines();
1964 new_summary.extend(lines.next());
1965
1966 // Stop if the LLM generated multiple lines.
1967 if lines.next().is_some() {
1968 break;
1969 }
1970 }
1971
1972 anyhow::Ok(new_summary)
1973 }
1974 .await;
1975
1976 this.update(cx, |this, cx| {
1977 match result {
1978 Ok(new_summary) => {
1979 if new_summary.is_empty() {
1980 this.summary = ThreadSummary::Error;
1981 } else {
1982 this.summary = ThreadSummary::Ready(new_summary.into());
1983 }
1984 }
1985 Err(err) => {
1986 this.summary = ThreadSummary::Error;
1987 log::error!("Failed to generate thread summary: {}", err);
1988 }
1989 }
1990 cx.emit(ThreadEvent::SummaryGenerated);
1991 })
1992 .log_err()?;
1993
1994 Some(())
1995 });
1996 }
1997
    /// Kicks off background generation of a detailed thread summary, unless
    /// one is already generated (or generating) for the current last message.
    ///
    /// Progress is published through `detailed_summary_tx`; on success the
    /// thread is saved via `thread_store` so the summary can be reused later.
    /// No-op when the thread is empty, no thread-summary model is configured,
    /// or its provider is not authenticated.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Request failed: roll the state back so a later call retries.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate the full streamed text; individual chunk errors are
            // logged and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2087
2088 pub async fn wait_for_detailed_summary_or_text(
2089 this: &Entity<Self>,
2090 cx: &mut AsyncApp,
2091 ) -> Option<SharedString> {
2092 let mut detailed_summary_rx = this
2093 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2094 .ok()?;
2095 loop {
2096 match detailed_summary_rx.recv().await? {
2097 DetailedSummaryState::Generating { .. } => {}
2098 DetailedSummaryState::NotGenerated => {
2099 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2100 }
2101 DetailedSummaryState::Generated { text, .. } => return Some(text),
2102 }
2103 }
2104 }
2105
2106 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2107 self.detailed_summary_rx
2108 .borrow()
2109 .text()
2110 .unwrap_or_else(|| self.text().into())
2111 }
2112
    /// Whether a detailed summary is currently being generated for this
    /// thread (see [`Self::start_generating_detailed_summary_if_needed`]).
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2119
2120 pub fn use_pending_tools(
2121 &mut self,
2122 window: Option<AnyWindowHandle>,
2123 model: Arc<dyn LanguageModel>,
2124 cx: &mut Context<Self>,
2125 ) -> Vec<PendingToolUse> {
2126 self.auto_capture_telemetry(cx);
2127 let request =
2128 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2129 let pending_tool_uses = self
2130 .tool_use
2131 .pending_tool_uses()
2132 .into_iter()
2133 .filter(|tool_use| tool_use.status.is_idle())
2134 .cloned()
2135 .collect::<Vec<_>>();
2136
2137 for tool_use in pending_tool_uses.iter() {
2138 self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx);
2139 }
2140
2141 pending_tool_uses
2142 }
2143
2144 fn use_pending_tool(
2145 &mut self,
2146 tool_use: PendingToolUse,
2147 request: Arc<LanguageModelRequest>,
2148 model: Arc<dyn LanguageModel>,
2149 window: Option<AnyWindowHandle>,
2150 cx: &mut Context<Self>,
2151 ) {
2152 let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else {
2153 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2154 };
2155
2156 if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
2157 return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx);
2158 }
2159
2160 if tool.needs_confirmation(&tool_use.input, cx)
2161 && !AgentSettings::get_global(cx).always_allow_tool_actions
2162 {
2163 self.tool_use.confirm_tool_use(
2164 tool_use.id,
2165 tool_use.ui_text,
2166 tool_use.input,
2167 request,
2168 tool,
2169 );
2170 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2171 } else {
2172 self.run_tool(
2173 tool_use.id,
2174 tool_use.ui_text,
2175 tool_use.input,
2176 request,
2177 tool,
2178 model,
2179 window,
2180 cx,
2181 );
2182 }
2183 }
2184
2185 pub fn handle_hallucinated_tool_use(
2186 &mut self,
2187 tool_use_id: LanguageModelToolUseId,
2188 hallucinated_tool_name: Arc<str>,
2189 window: Option<AnyWindowHandle>,
2190 cx: &mut Context<Thread>,
2191 ) {
2192 let available_tools = self.profile.enabled_tools(cx);
2193
2194 let tool_list = available_tools
2195 .iter()
2196 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2197 .collect::<Vec<_>>()
2198 .join("\n");
2199
2200 let error_message = format!(
2201 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2202 hallucinated_tool_name, tool_list
2203 );
2204
2205 let pending_tool_use = self.tool_use.insert_tool_output(
2206 tool_use_id.clone(),
2207 hallucinated_tool_name,
2208 Err(anyhow!("Missing tool call: {error_message}")),
2209 self.configured_model.as_ref(),
2210 );
2211
2212 cx.emit(ThreadEvent::MissingToolUse {
2213 tool_use_id: tool_use_id.clone(),
2214 ui_text: error_message.into(),
2215 });
2216
2217 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2218 }
2219
2220 pub fn receive_invalid_tool_json(
2221 &mut self,
2222 tool_use_id: LanguageModelToolUseId,
2223 tool_name: Arc<str>,
2224 invalid_json: Arc<str>,
2225 error: String,
2226 window: Option<AnyWindowHandle>,
2227 cx: &mut Context<Thread>,
2228 ) {
2229 log::error!("The model returned invalid input JSON: {invalid_json}");
2230
2231 let pending_tool_use = self.tool_use.insert_tool_output(
2232 tool_use_id.clone(),
2233 tool_name,
2234 Err(anyhow!("Error parsing input JSON: {error}")),
2235 self.configured_model.as_ref(),
2236 );
2237 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2238 pending_tool_use.ui_text.clone()
2239 } else {
2240 log::error!(
2241 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2242 );
2243 format!("Unknown tool {}", tool_use_id).into()
2244 };
2245
2246 cx.emit(ThreadEvent::InvalidToolInput {
2247 tool_use_id: tool_use_id.clone(),
2248 ui_text,
2249 invalid_input_json: invalid_json,
2250 });
2251
2252 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2253 }
2254
2255 pub fn run_tool(
2256 &mut self,
2257 tool_use_id: LanguageModelToolUseId,
2258 ui_text: impl Into<SharedString>,
2259 input: serde_json::Value,
2260 request: Arc<LanguageModelRequest>,
2261 tool: Arc<dyn Tool>,
2262 model: Arc<dyn LanguageModel>,
2263 window: Option<AnyWindowHandle>,
2264 cx: &mut Context<Thread>,
2265 ) {
2266 let task =
2267 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2268 self.tool_use
2269 .run_pending_tool(tool_use_id, ui_text.into(), task);
2270 }
2271
2272 fn spawn_tool_use(
2273 &mut self,
2274 tool_use_id: LanguageModelToolUseId,
2275 request: Arc<LanguageModelRequest>,
2276 input: serde_json::Value,
2277 tool: Arc<dyn Tool>,
2278 model: Arc<dyn LanguageModel>,
2279 window: Option<AnyWindowHandle>,
2280 cx: &mut Context<Thread>,
2281 ) -> Task<()> {
2282 let tool_name: Arc<str> = tool.name().into();
2283
2284 let tool_result = tool.run(
2285 input,
2286 request,
2287 self.project.clone(),
2288 self.action_log.clone(),
2289 model,
2290 window,
2291 cx,
2292 );
2293
2294 // Store the card separately if it exists
2295 if let Some(card) = tool_result.card.clone() {
2296 self.tool_use
2297 .insert_tool_result_card(tool_use_id.clone(), card);
2298 }
2299
2300 cx.spawn({
2301 async move |thread: WeakEntity<Thread>, cx| {
2302 let output = tool_result.output.await;
2303
2304 thread
2305 .update(cx, |thread, cx| {
2306 let pending_tool_use = thread.tool_use.insert_tool_output(
2307 tool_use_id.clone(),
2308 tool_name,
2309 output,
2310 thread.configured_model.as_ref(),
2311 );
2312 thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2313 })
2314 .ok();
2315 }
2316 })
2317 }
2318
2319 fn tool_finished(
2320 &mut self,
2321 tool_use_id: LanguageModelToolUseId,
2322 pending_tool_use: Option<PendingToolUse>,
2323 canceled: bool,
2324 window: Option<AnyWindowHandle>,
2325 cx: &mut Context<Self>,
2326 ) {
2327 if self.all_tools_finished() {
2328 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2329 if !canceled {
2330 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2331 }
2332 self.auto_capture_telemetry(cx);
2333 }
2334 }
2335
2336 cx.emit(ThreadEvent::ToolFinished {
2337 tool_use_id,
2338 pending_tool_use,
2339 });
2340 }
2341
2342 /// Cancels the last pending completion, if there are any pending.
2343 ///
2344 /// Returns whether a completion was canceled.
2345 pub fn cancel_last_completion(
2346 &mut self,
2347 window: Option<AnyWindowHandle>,
2348 cx: &mut Context<Self>,
2349 ) -> bool {
2350 let mut canceled = self.pending_completions.pop().is_some();
2351
2352 for pending_tool_use in self.tool_use.cancel_pending() {
2353 canceled = true;
2354 self.tool_finished(
2355 pending_tool_use.id.clone(),
2356 Some(pending_tool_use),
2357 true,
2358 window,
2359 cx,
2360 );
2361 }
2362
2363 if canceled {
2364 cx.emit(ThreadEvent::CompletionCanceled);
2365
2366 // When canceled, we always want to insert the checkpoint.
2367 // (We skip over finalize_pending_checkpoint, because it
2368 // would conclude we didn't have anything to insert here.)
2369 if let Some(checkpoint) = self.pending_checkpoint.take() {
2370 self.insert_checkpoint(checkpoint, cx);
2371 }
2372 } else {
2373 self.finalize_pending_checkpoint(cx);
2374 }
2375
2376 canceled
2377 }
2378
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners themselves perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2386
    /// Returns the thread-level feedback rating, if the user has given one.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2390
    /// Returns the feedback rating recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2394
2395 pub fn report_message_feedback(
2396 &mut self,
2397 message_id: MessageId,
2398 feedback: ThreadFeedback,
2399 cx: &mut Context<Self>,
2400 ) -> Task<Result<()>> {
2401 if self.message_feedback.get(&message_id) == Some(&feedback) {
2402 return Task::ready(Ok(()));
2403 }
2404
2405 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2406 let serialized_thread = self.serialize(cx);
2407 let thread_id = self.id().clone();
2408 let client = self.project.read(cx).client();
2409
2410 let enabled_tool_names: Vec<String> = self
2411 .profile
2412 .enabled_tools(cx)
2413 .iter()
2414 .map(|tool| tool.name())
2415 .collect();
2416
2417 self.message_feedback.insert(message_id, feedback);
2418
2419 cx.notify();
2420
2421 let message_content = self
2422 .message(message_id)
2423 .map(|msg| msg.to_string())
2424 .unwrap_or_default();
2425
2426 cx.background_spawn(async move {
2427 let final_project_snapshot = final_project_snapshot.await;
2428 let serialized_thread = serialized_thread.await?;
2429 let thread_data =
2430 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2431
2432 let rating = match feedback {
2433 ThreadFeedback::Positive => "positive",
2434 ThreadFeedback::Negative => "negative",
2435 };
2436 telemetry::event!(
2437 "Assistant Thread Rated",
2438 rating,
2439 thread_id,
2440 enabled_tool_names,
2441 message_id = message_id.0,
2442 message_content,
2443 thread_data,
2444 final_project_snapshot
2445 );
2446 client.telemetry().flush_events().await;
2447
2448 Ok(())
2449 })
2450 }
2451
2452 pub fn report_feedback(
2453 &mut self,
2454 feedback: ThreadFeedback,
2455 cx: &mut Context<Self>,
2456 ) -> Task<Result<()>> {
2457 let last_assistant_message_id = self
2458 .messages
2459 .iter()
2460 .rev()
2461 .find(|msg| msg.role == Role::Assistant)
2462 .map(|msg| msg.id);
2463
2464 if let Some(message_id) = last_assistant_message_id {
2465 self.report_message_feedback(message_id, feedback, cx)
2466 } else {
2467 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2468 let serialized_thread = self.serialize(cx);
2469 let thread_id = self.id().clone();
2470 let client = self.project.read(cx).client();
2471 self.feedback = Some(feedback);
2472 cx.notify();
2473
2474 cx.background_spawn(async move {
2475 let final_project_snapshot = final_project_snapshot.await;
2476 let serialized_thread = serialized_thread.await?;
2477 let thread_data = serde_json::to_value(serialized_thread)
2478 .unwrap_or_else(|_| serde_json::Value::Null);
2479
2480 let rating = match feedback {
2481 ThreadFeedback::Positive => "positive",
2482 ThreadFeedback::Negative => "negative",
2483 };
2484 telemetry::event!(
2485 "Assistant Thread Rated",
2486 rating,
2487 thread_id,
2488 thread_data,
2489 final_project_snapshot
2490 );
2491 client.telemetry().flush_events().await;
2492
2493 Ok(())
2494 })
2495 }
2496 }
2497
2498 /// Create a snapshot of the current project state including git information and unsaved buffers.
2499 fn project_snapshot(
2500 project: Entity<Project>,
2501 cx: &mut Context<Self>,
2502 ) -> Task<Arc<ProjectSnapshot>> {
2503 let git_store = project.read(cx).git_store().clone();
2504 let worktree_snapshots: Vec<_> = project
2505 .read(cx)
2506 .visible_worktrees(cx)
2507 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2508 .collect();
2509
2510 cx.spawn(async move |_, cx| {
2511 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2512
2513 let mut unsaved_buffers = Vec::new();
2514 cx.update(|app_cx| {
2515 let buffer_store = project.read(app_cx).buffer_store();
2516 for buffer_handle in buffer_store.read(app_cx).buffers() {
2517 let buffer = buffer_handle.read(app_cx);
2518 if buffer.is_dirty() {
2519 if let Some(file) = buffer.file() {
2520 let path = file.path().to_string_lossy().to_string();
2521 unsaved_buffers.push(path);
2522 }
2523 }
2524 }
2525 })
2526 .ok();
2527
2528 Arc::new(ProjectSnapshot {
2529 worktree_snapshots,
2530 unsaved_buffer_paths: unsaved_buffers,
2531 timestamp: Utc::now(),
2532 })
2533 })
2534 }
2535
    /// Captures a snapshot of a single worktree: its absolute path plus, when
    /// a repository covers it, the repo's remote URL, HEAD sha, current
    /// branch, and HEAD-to-worktree diff.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, return an empty placeholder snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository (if any) whose root contains this worktree,
            // then schedule a job against its backend to gather git metadata.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend we can
                            // query; other states yield branch info only.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<Task>>: any failure along the
            // way (dropped entity, job error) degrades to `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2613
2614 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2615 let mut markdown = Vec::new();
2616
2617 let summary = self.summary().or_default();
2618 writeln!(markdown, "# {summary}\n")?;
2619
2620 for message in self.messages() {
2621 writeln!(
2622 markdown,
2623 "## {role}\n",
2624 role = match message.role {
2625 Role::User => "User",
2626 Role::Assistant => "Agent",
2627 Role::System => "System",
2628 }
2629 )?;
2630
2631 if !message.loaded_context.text.is_empty() {
2632 writeln!(markdown, "{}", message.loaded_context.text)?;
2633 }
2634
2635 if !message.loaded_context.images.is_empty() {
2636 writeln!(
2637 markdown,
2638 "\n{} images attached as context.\n",
2639 message.loaded_context.images.len()
2640 )?;
2641 }
2642
2643 for segment in &message.segments {
2644 match segment {
2645 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2646 MessageSegment::Thinking { text, .. } => {
2647 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2648 }
2649 MessageSegment::RedactedThinking(_) => {}
2650 }
2651 }
2652
2653 for tool_use in self.tool_uses_for_message(message.id, cx) {
2654 writeln!(
2655 markdown,
2656 "**Use Tool: {} ({})**",
2657 tool_use.name, tool_use.id
2658 )?;
2659 writeln!(markdown, "```json")?;
2660 writeln!(
2661 markdown,
2662 "{}",
2663 serde_json::to_string_pretty(&tool_use.input)?
2664 )?;
2665 writeln!(markdown, "```")?;
2666 }
2667
2668 for tool_result in self.tool_results_for_message(message.id) {
2669 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2670 if tool_result.is_error {
2671 write!(markdown, " (Error)")?;
2672 }
2673
2674 writeln!(markdown, "**\n")?;
2675 match &tool_result.content {
2676 LanguageModelToolResultContent::Text(text) => {
2677 writeln!(markdown, "{text}")?;
2678 }
2679 LanguageModelToolResultContent::Image(image) => {
2680 writeln!(markdown, "", image.source)?;
2681 }
2682 }
2683
2684 if let Some(output) = tool_result.output.as_ref() {
2685 writeln!(
2686 markdown,
2687 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2688 serde_json::to_string_pretty(output)?
2689 )?;
2690 }
2691 }
2692 }
2693
2694 Ok(String::from_utf8_lossy(&markdown).to_string())
2695 }
2696
    /// Accepts (keeps) the agent's edits within the given anchor range of the
    /// buffer, delegating to the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2707
    /// Accepts (keeps) all of the agent's pending edits across all buffers.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2712
    /// Rejects (reverts) the agent's edits within the given anchor ranges of
    /// the buffer; the returned task resolves when the reversion completes.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2723
    /// The log of buffer edits the agent has made in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2727
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2731
2732 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2733 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2734 return;
2735 }
2736
2737 let now = Instant::now();
2738 if let Some(last) = self.last_auto_capture_at {
2739 if now.duration_since(last).as_secs() < 10 {
2740 return;
2741 }
2742 }
2743
2744 self.last_auto_capture_at = Some(now);
2745
2746 let thread_id = self.id().clone();
2747 let github_login = self
2748 .project
2749 .read(cx)
2750 .user_store()
2751 .read(cx)
2752 .current_user()
2753 .map(|user| user.github_login.clone());
2754 let client = self.project.read(cx).client();
2755 let serialize_task = self.serialize(cx);
2756
2757 cx.background_executor()
2758 .spawn(async move {
2759 if let Ok(serialized_thread) = serialize_task.await {
2760 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2761 telemetry::event!(
2762 "Agent Thread Auto-Captured",
2763 thread_id = thread_id.to_string(),
2764 thread_data = thread_data,
2765 auto_capture_reason = "tracked_user",
2766 github_login = github_login
2767 );
2768
2769 client.telemetry().flush_events().await;
2770 }
2771 }
2772 })
2773 .detach();
2774 }
2775
    /// Total token usage accumulated across every request made in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2779
2780 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2781 let Some(model) = self.configured_model.as_ref() else {
2782 return TotalTokenUsage::default();
2783 };
2784
2785 let max = model.model.max_token_count();
2786
2787 let index = self
2788 .messages
2789 .iter()
2790 .position(|msg| msg.id == message_id)
2791 .unwrap_or(0);
2792
2793 if index == 0 {
2794 return TotalTokenUsage { total: 0, max };
2795 }
2796
2797 let token_usage = &self
2798 .request_token_usage
2799 .get(index - 1)
2800 .cloned()
2801 .unwrap_or_default();
2802
2803 TotalTokenUsage {
2804 total: token_usage.total_tokens(),
2805 max,
2806 }
2807 }
2808
2809 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2810 let model = self.configured_model.as_ref()?;
2811
2812 let max = model.model.max_token_count();
2813
2814 if let Some(exceeded_error) = &self.exceeded_window_error {
2815 if model.model.id() == exceeded_error.model_id {
2816 return Some(TotalTokenUsage {
2817 total: exceeded_error.token_count,
2818 max,
2819 });
2820 }
2821 }
2822
2823 let total = self
2824 .token_usage_at_last_message()
2825 .unwrap_or_default()
2826 .total_tokens();
2827
2828 Some(TotalTokenUsage { total, max })
2829 }
2830
2831 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2832 self.request_token_usage
2833 .get(self.messages.len().saturating_sub(1))
2834 .or_else(|| self.request_token_usage.last())
2835 .cloned()
2836 }
2837
2838 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2839 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2840 self.request_token_usage
2841 .resize(self.messages.len(), placeholder);
2842
2843 if let Some(last) = self.request_token_usage.last_mut() {
2844 *last = token_usage;
2845 }
2846 }
2847
2848 fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
2849 self.project.update(cx, |project, cx| {
2850 project.user_store().update(cx, |user_store, cx| {
2851 user_store.update_model_request_usage(
2852 ModelRequestUsage(RequestUsage {
2853 amount: amount as i32,
2854 limit,
2855 }),
2856 cx,
2857 )
2858 })
2859 });
2860 }
2861
2862 pub fn deny_tool_use(
2863 &mut self,
2864 tool_use_id: LanguageModelToolUseId,
2865 tool_name: Arc<str>,
2866 window: Option<AnyWindowHandle>,
2867 cx: &mut Context<Self>,
2868 ) {
2869 let err = Err(anyhow::anyhow!(
2870 "Permission to run tool action denied by user"
2871 ));
2872
2873 self.tool_use.insert_tool_output(
2874 tool_use_id.clone(),
2875 tool_name,
2876 err,
2877 self.configured_model.as_ref(),
2878 );
2879 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2880 }
2881}
2882
/// User-visible error conditions surfaced by a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider rejected the request because payment is required.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has exhausted its model-request allowance.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a display header and message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2895
/// Events emitted by a [`Thread`] for UI layers and other listeners.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be displayed to the user.
    ShowError(ThreadError),
    /// A completion chunk was streamed in.
    StreamedCompletion,
    /// A chunk of assistant text arrived.
    ReceivedTextChunk,
    /// A new model request was started.
    NewRequest,
    /// Streamed assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant "thinking" text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model referenced a tool that doesn't exist or isn't enabled.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model supplied tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stream ended, with its stop reason or error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was deleted.
    MessageDeleted(MessageId),
    /// A thread summary was generated for the first time.
    SummaryGenerated,
    /// The thread summary changed.
    SummaryChanged,
    /// Idle pending tools were dispatched.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool use finished (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A git checkpoint was inserted or updated.
    CheckpointChanged,
    /// A tool is waiting for the user's confirmation before running.
    ToolConfirmationNeeded,
    /// The per-turn tool-use limit was hit.
    ToolUseLimitReached,
    /// Listeners should cancel any in-progress editing.
    CancelEditing,
    /// A pending completion was canceled.
    CompletionCanceled,
    /// The thread's agent profile changed.
    ProfileChanged,
}
2940
// Lets `Thread` contexts emit `ThreadEvent`s via `cx.emit(...)`.
impl EventEmitter<ThreadEvent> for Thread {}
2942
/// Bookkeeping for an in-flight model completion; dropping `_task` cancels
/// the underlying request.
struct PendingCompletion {
    // Monotonic id used to match completions when they finish.
    id: usize,
    // Where this completion sits in the send queue.
    queue_state: QueueState,
    // Keeps the completion future alive; aborted on drop.
    _task: Task<()>,
}
2948
2949/// Resolves tool name conflicts by ensuring all tool names are unique.
2950///
2951/// When multiple tools have the same name, this function applies the following rules:
2952/// 1. Native tools always keep their original name
2953/// 2. Context server tools get prefixed with their server ID and an underscore
2954/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
2955/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
2956///
2957/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
2958fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
2959 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
2960 let mut tool_name = tool.name();
2961 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
2962 tool_name
2963 }
2964
2965 const MAX_TOOL_NAME_LENGTH: usize = 64;
2966
2967 let mut duplicated_tool_names = HashSet::default();
2968 let mut seen_tool_names = HashSet::default();
2969 for tool in tools {
2970 let tool_name = resolve_tool_name(tool);
2971 if seen_tool_names.contains(&tool_name) {
2972 debug_assert!(
2973 tool.source() != assistant_tool::ToolSource::Native,
2974 "There are two built-in tools with the same name: {}",
2975 tool_name
2976 );
2977 duplicated_tool_names.insert(tool_name);
2978 } else {
2979 seen_tool_names.insert(tool_name);
2980 }
2981 }
2982
2983 if duplicated_tool_names.is_empty() {
2984 return tools
2985 .into_iter()
2986 .map(|tool| (resolve_tool_name(tool), tool.clone()))
2987 .collect();
2988 }
2989
2990 tools
2991 .into_iter()
2992 .filter_map(|tool| {
2993 let mut tool_name = resolve_tool_name(tool);
2994 if !duplicated_tool_names.contains(&tool_name) {
2995 return Some((tool_name, tool.clone()));
2996 }
2997 match tool.source() {
2998 assistant_tool::ToolSource::Native => {
2999 // Built-in tools always keep their original name
3000 Some((tool_name, tool.clone()))
3001 }
3002 assistant_tool::ToolSource::ContextServer { id } => {
3003 // Context server tools are prefixed with the context server ID, and truncated if necessary
3004 tool_name.insert(0, '_');
3005 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
3006 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
3007 let mut id = id.to_string();
3008 id.truncate(len);
3009 tool_name.insert_str(0, &id);
3010 } else {
3011 tool_name.insert_str(0, &id);
3012 }
3013
3014 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
3015
3016 if seen_tool_names.contains(&tool_name) {
3017 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
3018 None
3019 } else {
3020 Some((tool_name, tool.clone()))
3021 }
3022 }
3023 }
3024 })
3025 .collect()
3026}
3027
3028#[cfg(test)]
3029mod tests {
3030 use super::*;
3031 use crate::{
3032 context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore,
3033 };
3034 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3035 use assistant_tool::ToolRegistry;
3036 use gpui::TestAppContext;
3037 use icons::IconName;
3038 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3039 use project::{FakeFs, Project};
3040 use prompt_store::PromptBuilder;
3041 use serde_json::json;
3042 use settings::{Settings, SettingsStore};
3043 use std::sync::Arc;
3044 use theme::ThemeSettings;
3045 use util::path;
3046 use workspace::Workspace;
3047
    // Verifies that a user message with an attached file context stores the
    // rendered <context> block on the message and prepends it to the message
    // text in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // messages[0] is the system prompt; messages[1] is the user message
        // with the context prepended.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3124
    // Verifies that each new user message only attaches contexts that were
    // not already attached to an earlier message, and that
    // `new_context_for_thread` respects an "edit from this message" cursor.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages (plus the system prompt)
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cursor at message 2, contexts from message 2 onward count as
        // "new" again, so files 2-4 are returned but file 1 is not.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3280
3281 #[gpui::test]
3282 async fn test_message_without_files(cx: &mut TestAppContext) {
3283 init_test_settings(cx);
3284
3285 let project = create_test_project(
3286 cx,
3287 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3288 )
3289 .await;
3290
3291 let (_, _thread_store, thread, _context_store, model) =
3292 setup_test_environment(cx, project.clone()).await;
3293
3294 // Insert user message without any context (empty context vector)
3295 let message_id = thread.update(cx, |thread, cx| {
3296 thread.insert_user_message(
3297 "What is the best way to learn Rust?",
3298 ContextLoadResult::default(),
3299 None,
3300 Vec::new(),
3301 cx,
3302 )
3303 });
3304
3305 // Check content and context in message object
3306 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
3307
3308 // Context should be empty when no files are included
3309 assert_eq!(message.role, Role::User);
3310 assert_eq!(message.segments.len(), 1);
3311 assert_eq!(
3312 message.segments[0],
3313 MessageSegment::Text("What is the best way to learn Rust?".to_string())
3314 );
3315 assert_eq!(message.loaded_context.text, "");
3316
3317 // Check message in request
3318 let request = thread.update(cx, |thread, cx| {
3319 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3320 });
3321
3322 assert_eq!(request.messages.len(), 2);
3323 assert_eq!(
3324 request.messages[1].string_contents(),
3325 "What is the best way to learn Rust?"
3326 );
3327
3328 // Add second message, also without context
3329 let message2_id = thread.update(cx, |thread, cx| {
3330 thread.insert_user_message(
3331 "Are there any good books?",
3332 ContextLoadResult::default(),
3333 None,
3334 Vec::new(),
3335 cx,
3336 )
3337 });
3338
3339 let message2 =
3340 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
3341 assert_eq!(message2.loaded_context.text, "");
3342
3343 // Check that both messages appear in the request
3344 let request = thread.update(cx, |thread, cx| {
3345 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3346 });
3347
3348 assert_eq!(request.messages.len(), 3);
3349 assert_eq!(
3350 request.messages[1].string_contents(),
3351 "What is the best way to learn Rust?"
3352 );
3353 assert_eq!(
3354 request.messages[2].string_contents(),
3355 "Are there any good books?"
3356 );
3357 }
3358
3359 #[gpui::test]
3360 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3361 init_test_settings(cx);
3362
3363 let project = create_test_project(
3364 cx,
3365 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3366 )
3367 .await;
3368
3369 let (_workspace, thread_store, thread, _context_store, _model) =
3370 setup_test_environment(cx, project.clone()).await;
3371
3372 // Check that we are starting with the default profile
3373 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3374 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3375 assert_eq!(
3376 profile,
3377 AgentProfile::new(AgentProfileId::default(), tool_set)
3378 );
3379 }
3380
3381 #[gpui::test]
3382 async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
3383 init_test_settings(cx);
3384
3385 let project = create_test_project(
3386 cx,
3387 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3388 )
3389 .await;
3390
3391 let (_workspace, thread_store, thread, _context_store, _model) =
3392 setup_test_environment(cx, project.clone()).await;
3393
3394 // Profile gets serialized with default values
3395 let serialized = thread
3396 .update(cx, |thread, cx| thread.serialize(cx))
3397 .await
3398 .unwrap();
3399
3400 assert_eq!(serialized.profile, Some(AgentProfileId::default()));
3401
3402 let deserialized = cx.update(|cx| {
3403 thread.update(cx, |thread, cx| {
3404 Thread::deserialize(
3405 thread.id.clone(),
3406 serialized,
3407 thread.project.clone(),
3408 thread.tools.clone(),
3409 thread.prompt_builder.clone(),
3410 thread.project_context.clone(),
3411 None,
3412 cx,
3413 )
3414 })
3415 });
3416 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3417
3418 assert_eq!(
3419 deserialized.profile,
3420 AgentProfile::new(AgentProfileId::default(), tool_set)
3421 );
3422 }
3423
3424 #[gpui::test]
3425 async fn test_temperature_setting(cx: &mut TestAppContext) {
3426 init_test_settings(cx);
3427
3428 let project = create_test_project(
3429 cx,
3430 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3431 )
3432 .await;
3433
3434 let (_workspace, _thread_store, thread, _context_store, model) =
3435 setup_test_environment(cx, project.clone()).await;
3436
3437 // Both model and provider
3438 cx.update(|cx| {
3439 AgentSettings::override_global(
3440 AgentSettings {
3441 model_parameters: vec![LanguageModelParameters {
3442 provider: Some(model.provider_id().0.to_string().into()),
3443 model: Some(model.id().0.clone()),
3444 temperature: Some(0.66),
3445 }],
3446 ..AgentSettings::get_global(cx).clone()
3447 },
3448 cx,
3449 );
3450 });
3451
3452 let request = thread.update(cx, |thread, cx| {
3453 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3454 });
3455 assert_eq!(request.temperature, Some(0.66));
3456
3457 // Only model
3458 cx.update(|cx| {
3459 AgentSettings::override_global(
3460 AgentSettings {
3461 model_parameters: vec![LanguageModelParameters {
3462 provider: None,
3463 model: Some(model.id().0.clone()),
3464 temperature: Some(0.66),
3465 }],
3466 ..AgentSettings::get_global(cx).clone()
3467 },
3468 cx,
3469 );
3470 });
3471
3472 let request = thread.update(cx, |thread, cx| {
3473 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3474 });
3475 assert_eq!(request.temperature, Some(0.66));
3476
3477 // Only provider
3478 cx.update(|cx| {
3479 AgentSettings::override_global(
3480 AgentSettings {
3481 model_parameters: vec![LanguageModelParameters {
3482 provider: Some(model.provider_id().0.to_string().into()),
3483 model: None,
3484 temperature: Some(0.66),
3485 }],
3486 ..AgentSettings::get_global(cx).clone()
3487 },
3488 cx,
3489 );
3490 });
3491
3492 let request = thread.update(cx, |thread, cx| {
3493 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3494 });
3495 assert_eq!(request.temperature, Some(0.66));
3496
3497 // Same model name, different provider
3498 cx.update(|cx| {
3499 AgentSettings::override_global(
3500 AgentSettings {
3501 model_parameters: vec![LanguageModelParameters {
3502 provider: Some("anthropic".into()),
3503 model: Some(model.id().0.clone()),
3504 temperature: Some(0.66),
3505 }],
3506 ..AgentSettings::get_global(cx).clone()
3507 },
3508 cx,
3509 );
3510 });
3511
3512 let request = thread.update(cx, |thread, cx| {
3513 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3514 });
3515 assert_eq!(request.temperature, None);
3516 }
3517
    /// Walks the summary through its full lifecycle: `Pending` (rejecting
    /// manual edits), `Generating` after the first exchange (still rejecting
    /// manual edits), then `Ready` once the fake model streams a title, after
    /// which manual edits are accepted.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set (chunks concatenated in arrival order)
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3602
3603 #[gpui::test]
3604 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3605 init_test_settings(cx);
3606
3607 let project = create_test_project(cx, json!({})).await;
3608
3609 let (_, _thread_store, thread, _context_store, model) =
3610 setup_test_environment(cx, project.clone()).await;
3611
3612 test_summarize_error(&model, &thread, cx);
3613
3614 // Now we should be able to set a summary
3615 thread.update(cx, |thread, cx| {
3616 thread.set_summary("Brief Intro", cx);
3617 });
3618
3619 thread.read_with(cx, |thread, _| {
3620 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3621 assert_eq!(thread.summary().or_default(), "Brief Intro");
3622 });
3623 }
3624
    /// After a failed summarization, sending further messages must not
    /// re-trigger an automatic summary request, but an explicit
    /// `Thread::summarize` call retries and can succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Put the thread into the summary-error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary this time.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3675
3676 #[gpui::test]
3677 fn test_resolve_tool_name_conflicts() {
3678 use assistant_tool::{Tool, ToolSource};
3679
3680 assert_resolve_tool_name_conflicts(
3681 vec![
3682 TestTool::new("tool1", ToolSource::Native),
3683 TestTool::new("tool2", ToolSource::Native),
3684 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3685 ],
3686 vec!["tool1", "tool2", "tool3"],
3687 );
3688
3689 assert_resolve_tool_name_conflicts(
3690 vec![
3691 TestTool::new("tool1", ToolSource::Native),
3692 TestTool::new("tool2", ToolSource::Native),
3693 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3694 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
3695 ],
3696 vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
3697 );
3698
3699 assert_resolve_tool_name_conflicts(
3700 vec![
3701 TestTool::new("tool1", ToolSource::Native),
3702 TestTool::new("tool2", ToolSource::Native),
3703 TestTool::new("tool3", ToolSource::Native),
3704 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3705 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
3706 ],
3707 vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
3708 );
3709
3710 // Test that tool with very long name is always truncated
3711 assert_resolve_tool_name_conflicts(
3712 vec![TestTool::new(
3713 "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
3714 ToolSource::Native,
3715 )],
3716 vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
3717 );
3718
3719 // Test deduplication of tools with very long names, in this case the mcp server name should be truncated
3720 assert_resolve_tool_name_conflicts(
3721 vec![
3722 TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
3723 TestTool::new(
3724 "tool-with-very-very-very-long-name",
3725 ToolSource::ContextServer {
3726 id: "mcp-with-very-very-very-long-name".into(),
3727 },
3728 ),
3729 ],
3730 vec![
3731 "tool-with-very-very-very-long-name",
3732 "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
3733 ],
3734 );
3735
3736 fn assert_resolve_tool_name_conflicts(
3737 tools: Vec<TestTool>,
3738 expected: Vec<impl Into<String>>,
3739 ) {
3740 let tools: Vec<Arc<dyn Tool>> = tools
3741 .into_iter()
3742 .map(|t| Arc::new(t) as Arc<dyn Tool>)
3743 .collect();
3744 let tools = resolve_tool_name_conflicts(&tools);
3745 assert_eq!(tools.len(), expected.len());
3746 for (i, expected_name) in expected.into_iter().enumerate() {
3747 let expected_name = expected_name.into();
3748 let actual_name = &tools[i].0;
3749 assert_eq!(
3750 actual_name, &expected_name,
3751 "Expected '{}' got '{}' at index {}",
3752 expected_name, actual_name, i
3753 );
3754 }
3755 }
3756
3757 struct TestTool {
3758 name: String,
3759 source: ToolSource,
3760 }
3761
3762 impl TestTool {
3763 fn new(name: impl Into<String>, source: ToolSource) -> Self {
3764 Self {
3765 name: name.into(),
3766 source,
3767 }
3768 }
3769 }
3770
3771 impl Tool for TestTool {
3772 fn name(&self) -> String {
3773 self.name.clone()
3774 }
3775
3776 fn icon(&self) -> IconName {
3777 IconName::Ai
3778 }
3779
3780 fn may_perform_edits(&self) -> bool {
3781 false
3782 }
3783
3784 fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
3785 true
3786 }
3787
3788 fn source(&self) -> ToolSource {
3789 self.source.clone()
3790 }
3791
3792 fn description(&self) -> String {
3793 "Test tool".to_string()
3794 }
3795
3796 fn ui_text(&self, _input: &serde_json::Value) -> String {
3797 "Test tool".to_string()
3798 }
3799
3800 fn run(
3801 self: Arc<Self>,
3802 _input: serde_json::Value,
3803 _request: Arc<LanguageModelRequest>,
3804 _project: Entity<Project>,
3805 _action_log: Entity<ActionLog>,
3806 _model: Arc<dyn LanguageModel>,
3807 _window: Option<AnyWindowHandle>,
3808 _cx: &mut App,
3809 ) -> assistant_tool::ToolResult {
3810 assistant_tool::ToolResult {
3811 output: Task::ready(Err(anyhow::anyhow!("No content"))),
3812 card: None,
3813 }
3814 }
3815 }
3816 }
3817
3818 fn test_summarize_error(
3819 model: &Arc<dyn LanguageModel>,
3820 thread: &Entity<Thread>,
3821 cx: &mut TestAppContext,
3822 ) {
3823 thread.update(cx, |thread, cx| {
3824 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3825 thread.send_to_model(
3826 model.clone(),
3827 CompletionIntent::ThreadSummarization,
3828 None,
3829 cx,
3830 );
3831 });
3832
3833 let fake_model = model.as_fake();
3834 simulate_successful_response(&fake_model, cx);
3835
3836 thread.read_with(cx, |thread, _| {
3837 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3838 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3839 });
3840
3841 // Simulate summary request ending
3842 cx.run_until_parked();
3843 fake_model.end_last_completion_stream();
3844 cx.run_until_parked();
3845
3846 // State is set to Error and default message
3847 thread.read_with(cx, |thread, _| {
3848 assert!(matches!(thread.summary(), ThreadSummary::Error));
3849 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3850 });
3851 }
3852
    /// Streams a canned assistant reply on `fake_model`'s most recent
    /// completion and closes the stream, running pending tasks to quiescence
    /// (`run_until_parked`) before and after so the thread observes both the
    /// response and the stream's end.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3859
    /// Installs the global settings store and registers every settings type
    /// and global the tests in this module rely on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store global must exist before anything below
            // registers or reads settings.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3875
3876 // Helper to create a test project with test files
3877 async fn create_test_project(
3878 cx: &mut TestAppContext,
3879 files: serde_json::Value,
3880 ) -> Entity<Project> {
3881 let fs = FakeFs::new(cx.executor());
3882 fs.insert_tree(path!("/test"), files).await;
3883 Project::test(fs, [path!("/test").as_ref()], cx).await
3884 }
3885
3886 async fn setup_test_environment(
3887 cx: &mut TestAppContext,
3888 project: Entity<Project>,
3889 ) -> (
3890 Entity<Workspace>,
3891 Entity<ThreadStore>,
3892 Entity<Thread>,
3893 Entity<ContextStore>,
3894 Arc<dyn LanguageModel>,
3895 ) {
3896 let (workspace, cx) =
3897 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3898
3899 let thread_store = cx
3900 .update(|_, cx| {
3901 ThreadStore::load(
3902 project.clone(),
3903 cx.new(|_| ToolWorkingSet::default()),
3904 None,
3905 Arc::new(PromptBuilder::new(None).unwrap()),
3906 cx,
3907 )
3908 })
3909 .await
3910 .unwrap();
3911
3912 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3913 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3914
3915 let provider = Arc::new(FakeLanguageModelProvider);
3916 let model = provider.test_model();
3917 let model: Arc<dyn LanguageModel> = Arc::new(model);
3918
3919 cx.update(|_, cx| {
3920 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3921 registry.set_default_model(
3922 Some(ConfiguredModel {
3923 provider: provider.clone(),
3924 model: model.clone(),
3925 }),
3926 cx,
3927 );
3928 registry.set_thread_summary_model(
3929 Some(ConfiguredModel {
3930 provider,
3931 model: model.clone(),
3932 }),
3933 cx,
3934 );
3935 })
3936 });
3937
3938 (workspace, thread_store, thread, context_store, model)
3939 }
3940
3941 async fn add_file_to_context(
3942 project: &Entity<Project>,
3943 context_store: &Entity<ContextStore>,
3944 path: &str,
3945 cx: &mut TestAppContext,
3946 ) -> Result<Entity<language::Buffer>> {
3947 let buffer_path = project
3948 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3949 .unwrap();
3950
3951 let buffer = project
3952 .update(cx, |project, cx| {
3953 project.open_buffer(buffer_path.clone(), cx)
3954 })
3955 .await
3956 .unwrap();
3957
3958 context_store.update(cx, |context_store, cx| {
3959 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3960 });
3961
3962 Ok(buffer)
3963 }
3964}