1use std::io::Write;
2use std::ops::Range;
3use std::sync::Arc;
4use std::time::Instant;
5
6use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
7use anyhow::{Result, anyhow};
8use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
9use chrono::{DateTime, Utc};
10use client::{ModelRequestUsage, RequestUsage};
11use collections::{HashMap, HashSet};
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
27 TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40
41use uuid::Uuid;
42use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
43
44use crate::ThreadStore;
45use crate::agent_profile::AgentProfile;
46use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
47use crate::thread_store::{
48 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
49 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
50};
51use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
52
/// Globally unique identifier for a [`Thread`], backed by a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
57
58impl ThreadId {
59 pub fn new() -> Self {
60 Self(Uuid::new_v4().to_string().into())
61 }
62}
63
64impl std::fmt::Display for ThreadId {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 write!(f, "{}", self.0)
67 }
68}
69
70impl From<&str> for ThreadId {
71 fn from(value: &str) -> Self {
72 Self(value.into())
73 }
74}
75
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// A fresh ID is generated for each submission (see [`Thread::advance_prompt_id`]).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
81
82impl PromptId {
83 pub fn new() -> Self {
84 Self(Uuid::new_v4().to_string().into())
85 }
86}
87
88impl std::fmt::Display for PromptId {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 write!(f, "{}", self.0)
91 }
92}
93
94#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
95pub struct MessageId(pub(crate) usize);
96
97impl MessageId {
98 fn post_inc(&mut self) -> Self {
99 Self(post_inc(&mut self.0))
100 }
101}
102
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range of the crease within the message text.
    pub range: Range<usize>,
    /// Display metadata (icon and label) used to re-render the crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces: plain text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context that was attached when this message was sent.
    pub loaded_context: LoadedContext,
    pub creases: Vec<MessageCrease>,
    /// Hidden messages (e.g. the synthetic "continue" prompt) are kept in the
    /// transcript but excluded from turn-boundary logic; see `is_turn_end`.
    pub is_hidden: bool,
}
122
123impl Message {
124 /// Returns whether the message contains any meaningful text that should be displayed
125 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
126 pub fn should_display_content(&self) -> bool {
127 self.segments.iter().all(|segment| segment.should_display())
128 }
129
130 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
131 if let Some(MessageSegment::Thinking {
132 text: segment,
133 signature: current_signature,
134 }) = self.segments.last_mut()
135 {
136 if let Some(signature) = signature {
137 *current_signature = Some(signature);
138 }
139 segment.push_str(text);
140 } else {
141 self.segments.push(MessageSegment::Thinking {
142 text: text.to_string(),
143 signature,
144 });
145 }
146 }
147
148 pub fn push_text(&mut self, text: &str) {
149 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
150 segment.push_str(text);
151 } else {
152 self.segments.push(MessageSegment::Text(text.to_string()));
153 }
154 }
155
156 pub fn to_string(&self) -> String {
157 let mut result = String::new();
158
159 if !self.loaded_context.text.is_empty() {
160 result.push_str(&self.loaded_context.text);
161 }
162
163 for segment in &self.segments {
164 match segment {
165 MessageSegment::Text(text) => result.push_str(text),
166 MessageSegment::Thinking { text, .. } => {
167 result.push_str("<think>\n");
168 result.push_str(text);
169 result.push_str("\n</think>");
170 }
171 MessageSegment::RedactedThinking(_) => {}
172 }
173 }
174
175 result
176 }
177}
178
/// One content piece of a [`Message`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment carries user-visible content.
    ///
    /// Bug fix: this predicate was inverted (`text.is_empty()`), which
    /// reported empty segments as displayable and non-empty ones as hidden,
    /// contradicting its use in `Message::should_display_content`.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            // Redacted thinking is opaque provider data and is never shown.
            Self::RedactedThinking(_) => false,
        }
    }
}
198
/// Snapshot of the project's worktrees and unsaved-buffer state at a point in time.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree snapshot: its path plus git state, when available.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Captured git repository state for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// Associates a git checkpoint with the message it was captured for.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback signal for a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
231
/// State of the most recent checkpoint-restore attempt.
pub enum LastRestoreCheckpoint {
    /// The restore is still running.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed; `error` holds the failure description.
    Error {
        message_id: MessageId,
        error: String,
    },
}
241
242impl LastRestoreCheckpoint {
243 pub fn message_id(&self) -> MessageId {
244 match self {
245 LastRestoreCheckpoint::Pending { message_id } => *message_id,
246 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
247 }
248 }
249}
250
/// Lifecycle of the thread's detailed (long-form) summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
    #[default]
    NotGenerated,
    /// A detailed summary is currently being generated.
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary has been produced.
    // NOTE(review): presumably `message_id` marks the last message the summary
    // covers — confirm against the code that sets this state.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
263
264impl DetailedSummaryState {
265 fn text(&self) -> Option<SharedString> {
266 if let Self::Generated { text, .. } = self {
267 Some(text.clone())
268 } else {
269 None
270 }
271 }
272}
273
/// Running token total for a thread together with the model's context limit.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: u64,
    /// Context-window size of the selected model; 0 when unknown.
    pub max: u64,
}

impl TotalTokenUsage {
    /// Classifies current usage against the context window.
    pub fn ratio(&self) -> TokenUsageRatio {
        // In debug builds the warning threshold can be overridden for testing.
        // Fix: a malformed env value previously hit `.unwrap()` and panicked;
        // now it falls back to the default threshold.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a new usage with `tokens` added to the total.
    ///
    /// Uses `saturating_add` so a pathological token count cannot overflow
    /// (which would panic in debug builds and wrap in release builds).
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total.saturating_add(tokens),
            max: self.max,
        }
    }
}

/// Severity bucket produced by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
318
/// Where a pending completion stands in the provider's request queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
325
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    // Watch channel so observers can react to detailed-summary state changes.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Ordered by ascending `MessageId`; `message()` binary-searches on this.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken when the latest user message was sent; finalized or
    // discarded once generation for that turn completes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recent streamed chunk; drives staleness checks.
    last_received_chunk_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
366
/// Lifecycle of a thread's auto-generated one-line summary (its title).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    Pending,
    Generating,
    Ready(SharedString),
    Error,
}
374
375impl ThreadSummary {
376 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
377
378 pub fn or_default(&self) -> SharedString {
379 self.unwrap_or(Self::DEFAULT)
380 }
381
382 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
383 self.ready().unwrap_or_else(|| message.into())
384 }
385
386 pub fn ready(&self) -> Option<SharedString> {
387 match self {
388 ThreadSummary::Ready(summary) => Some(summary.clone()),
389 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
390 }
391 }
392}
393
/// Recorded when a request exceeded the selected model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
401
402impl Thread {
    /// Creates an empty thread, drawing its defaults (model, profile,
    /// completion mode) from the global registry and agent settings.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project state up front so later comparisons and
                // telemetry can reference the state at thread creation.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
457
    /// Rebuilds a thread from its serialized form.
    ///
    /// The serialized model is resolved against the registry, falling back to
    /// the default model; missing completion mode / profile fall back to the
    /// global settings. Restored creases carry `context: None`, since live
    /// context handles are not serialized.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue ID allocation after the highest serialized message ID.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Live context handles cannot be restored.
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
581
    /// Installs a callback invoked with each completion request and the raw
    /// events it produced; used to capture/inspect model traffic.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
589
    /// The unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// The agent profile (enabled tool set) currently in effect.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    /// Switches to the profile with `id`, emitting a change event only when
    /// the profile actually changes.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a new prompt ID for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Shared handle to the project context used when building the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
624
    /// Returns the configured model, lazily falling back to (and caching) the
    /// global default model when none has been selected for this thread yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }
631
    /// The model currently configured for this thread, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the model used by this thread.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Current state of the thread's auto-generated summary.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
644
    /// Replaces the summary, substituting the default title for an empty
    /// string. Ignored while a summary is pending or generating; emits
    /// [`ThreadEvent::SummaryChanged`] only on an actual change.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            // An errored summary may always be overwritten; comparing against
            // the default means setting the default title counts as no change.
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }
663
    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Sets the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    /// Looks up a message by ID.
    ///
    /// Uses binary search, relying on `messages` being ordered by ascending
    /// `MessageId` (IDs are allocated monotonically on insert).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
684
    /// Whether a completion is in flight or tool uses are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    ///
    /// NOTE(review): the implementation only consults `last_received_chunk_at`
    /// and returns `None` when no chunk has ever been recorded — confirm that
    /// the field is cleared when generation stops, otherwise the guarantee
    /// stated above is not upheld here.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a new chunk before the stream counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
711
712 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
713 self.tool_use
714 .pending_tool_uses()
715 .into_iter()
716 .find(|tool_use| &tool_use.id == id)
717 }
718
    /// Pending tool uses that require explicit user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool use is still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for message `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
733
    /// Restores the project to `checkpoint`'s git state, truncating the thread
    /// back to the checkpoint's message on success and recording the error on
    /// failure. Progress is surfaced via `last_restore_checkpoint`.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // On success, drop the checkpointed message and everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
769
770 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
771 let pending_checkpoint = if self.is_generating() {
772 return;
773 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
774 checkpoint
775 } else {
776 return;
777 };
778
779 self.finalize_checkpoint(pending_checkpoint, cx);
780 }
781
    /// Compares `pending_checkpoint` against the repository's current state and
    /// keeps it only when they differ (i.e. the turn actually changed files).
    /// If taking or comparing checkpoints fails, the checkpoint is kept
    /// conservatively so the user can still restore.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison is treated as "not equal" so the
                // checkpoint is retained.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Could not capture a final checkpoint; keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
816
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// The state of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
827
    /// Removes the message with `message_id` and every later message, dropping
    /// any checkpoints recorded for the removed messages. No-op when the ID is
    /// not found.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
841
    /// The contexts attached to the message with `id`; empty when not found.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
849
850 pub fn is_turn_end(&self, ix: usize) -> bool {
851 if self.messages.is_empty() {
852 return false;
853 }
854
855 if !self.is_generating() && ix == self.messages.len() - 1 {
856 return true;
857 }
858
859 let Some(message) = self.messages.get(ix) else {
860 return false;
861 };
862
863 if message.role != Role::Assistant {
864 return false;
865 }
866
867 self.messages
868 .get(ix + 1)
869 .and_then(|message| {
870 self.message(message.id)
871 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
872 })
873 .unwrap_or(false)
874 }
875
    /// Whether the model stopped because it reached the tool-use limit.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            // Errored tool uses are finished and cannot edit anything.
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }
898
    /// All tool uses issued by the message with `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// All tool results produced for the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The textual output of a tool use; `None` for image results.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// The UI card associated with a tool use, if one was produced.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
927
    /// Return tools that are both enabled and supported by the model
    pub fn available_tools(
        &self,
        cx: &App,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<LanguageModelRequestTool> {
        if model.supports_tools() {
            resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
                .into_iter()
                .filter_map(|(name, tool)| {
                    // Skip tools that cannot be supported
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name,
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect()
        } else {
            // Models without tool support get an empty tool list.
            Vec::default()
        }
    }
951
    /// Appends a user message, marking context-referenced buffers as read in
    /// the action log and stashing `git_checkpoint` as this turn's pending
    /// checkpoint. Returns the new message's ID.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
988
    /// Inserts a hidden user message asking the model to continue; the pending
    /// checkpoint is cleared since no user-visible turn begins here.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }

    /// Appends an assistant message composed of `segments`.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1017
    /// Appends a message with the next sequential ID, bumps `updated_at`, and
    /// emits [`ThreadEvent::MessageAdded`]. Returns the new message's ID.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1040
    /// Rewrites an existing message in place, optionally replacing its loaded
    /// context and recording a new checkpoint for it. Returns `false` when no
    /// message with `id` exists.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1073
1074 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1075 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1076 return false;
1077 };
1078 self.messages.remove(index);
1079 self.touch_updated_at();
1080 cx.emit(ThreadEvent::MessageDeleted(id));
1081 true
1082 }
1083
1084 /// Returns the representation of this [`Thread`] in a textual form.
1085 ///
1086 /// This is the representation we use when attaching a thread as context to another thread.
1087 pub fn text(&self) -> String {
1088 let mut text = String::new();
1089
1090 for message in &self.messages {
1091 text.push_str(match message.role {
1092 language_model::Role::User => "User:",
1093 language_model::Role::Assistant => "Agent:",
1094 language_model::Role::System => "System:",
1095 });
1096 text.push('\n');
1097
1098 for segment in &message.segments {
1099 match segment {
1100 MessageSegment::Text(content) => text.push_str(content),
1101 MessageSegment::Thinking { text: content, .. } => {
1102 text.push_str(&format!("<think>{}</think>", content))
1103 }
1104 MessageSegment::RedactedThinking(_) => {}
1105 }
1106 }
1107 text.push('\n');
1108 }
1109
1110 text
1111 }
1112
/// Serializes this thread into a format for storage or telemetry.
///
/// Returns a task because the initial project snapshot is a shared future
/// that may not have resolved yet; everything else is read synchronously
/// from the thread once it has.
pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
    // Clone the shared future up front so the spawned task owns it
    // instead of borrowing `self`.
    let initial_project_snapshot = self.initial_project_snapshot.clone();
    cx.spawn(async move |this, cx| {
        // Wait for the snapshot captured at thread creation to resolve.
        let initial_project_snapshot = initial_project_snapshot.await;
        this.read_with(cx, |this, cx| SerializedThread {
            version: SerializedThread::VERSION.to_string(),
            summary: this.summary().or_default(),
            updated_at: this.updated_at(),
            messages: this
                .messages()
                .map(|message| SerializedMessage {
                    id: message.id,
                    role: message.role,
                    // Each in-memory segment maps 1:1 onto its serialized twin.
                    segments: message
                        .segments
                        .iter()
                        .map(|segment| match segment {
                            MessageSegment::Text(text) => {
                                SerializedMessageSegment::Text { text: text.clone() }
                            }
                            MessageSegment::Thinking { text, signature } => {
                                SerializedMessageSegment::Thinking {
                                    text: text.clone(),
                                    signature: signature.clone(),
                                }
                            }
                            MessageSegment::RedactedThinking(data) => {
                                SerializedMessageSegment::RedactedThinking {
                                    data: data.clone(),
                                }
                            }
                        })
                        .collect(),
                    // Tool uses and results are stored per message id.
                    tool_uses: this
                        .tool_uses_for_message(message.id, cx)
                        .into_iter()
                        .map(|tool_use| SerializedToolUse {
                            id: tool_use.id,
                            name: tool_use.name,
                            input: tool_use.input,
                        })
                        .collect(),
                    tool_results: this
                        .tool_results_for_message(message.id)
                        .into_iter()
                        .map(|tool_result| SerializedToolResult {
                            tool_use_id: tool_result.tool_use_id.clone(),
                            is_error: tool_result.is_error,
                            content: tool_result.content.clone(),
                            output: tool_result.output.clone(),
                        })
                        .collect(),
                    context: message.loaded_context.text.clone(),
                    // Creases are persisted as plain offset ranges plus the
                    // display metadata needed to re-render them.
                    creases: message
                        .creases
                        .iter()
                        .map(|crease| SerializedCrease {
                            start: crease.range.start,
                            end: crease.range.end,
                            icon_path: crease.metadata.icon_path.clone(),
                            label: crease.metadata.label.clone(),
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            initial_project_snapshot,
            cumulative_token_usage: this.cumulative_token_usage,
            request_token_usage: this.request_token_usage.clone(),
            detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
            exceeded_window_error: this.exceeded_window_error.clone(),
            // Record provider/model ids so the same model can be restored
            // when the thread is loaded again.
            model: this
                .configured_model
                .as_ref()
                .map(|model| SerializedLanguageModel {
                    provider: model.provider.id().0.to_string(),
                    model: model.model.id().0.to_string(),
                }),
            completion_mode: Some(this.completion_mode),
            tool_use_limit_reached: this.tool_use_limit_reached,
            profile: Some(this.profile.id().clone()),
        })
    })
}
1198
/// Returns how many model turns this thread may still take before
/// [`Self::send_to_model`] starts refusing to send requests.
pub fn remaining_turns(&self) -> u32 {
    self.remaining_turns
}
1202
/// Sets the turn budget consumed by [`Self::send_to_model`].
pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
    self.remaining_turns = remaining_turns;
}
1206
1207 pub fn send_to_model(
1208 &mut self,
1209 model: Arc<dyn LanguageModel>,
1210 intent: CompletionIntent,
1211 window: Option<AnyWindowHandle>,
1212 cx: &mut Context<Self>,
1213 ) {
1214 if self.remaining_turns == 0 {
1215 return;
1216 }
1217
1218 self.remaining_turns -= 1;
1219
1220 let request = self.to_completion_request(model.clone(), intent, cx);
1221
1222 self.stream_completion(request, model, window, cx);
1223 }
1224
1225 pub fn used_tools_since_last_user_message(&self) -> bool {
1226 for message in self.messages.iter().rev() {
1227 if self.tool_use.message_has_tool_results(message.id) {
1228 return true;
1229 } else if message.role == Role::User {
1230 return false;
1231 }
1232 }
1233
1234 false
1235 }
1236
/// Builds the [`LanguageModelRequest`] for this thread: system prompt,
/// conversation messages (including tool uses and their results), the
/// available tools, and the completion mode.
pub fn to_completion_request(
    &self,
    model: Arc<dyn LanguageModel>,
    intent: CompletionIntent,
    cx: &mut Context<Self>,
) -> LanguageModelRequest {
    let mut request = LanguageModelRequest {
        thread_id: Some(self.id.to_string()),
        prompt_id: Some(self.last_prompt_id.to_string()),
        intent: Some(intent),
        mode: None,
        messages: vec![],
        tools: Vec::new(),
        tool_choice: None,
        stop: Vec::new(),
        temperature: AgentSettings::temperature_for_model(&model, cx),
    };

    // Tool names are passed to the prompt builder so the system prompt can
    // reference what the model may call.
    let available_tools = self.available_tools(cx, model.clone());
    let available_tool_names = available_tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();

    let model_context = &ModelContext {
        available_tools: available_tool_names,
    };

    // Generate the system prompt; failures are surfaced to the user instead
    // of silently sending a request without one.
    if let Some(project_context) = self.project_context.borrow().as_ref() {
        match self
            .prompt_builder
            .generate_assistant_system_prompt(project_context, model_context)
        {
            Err(err) => {
                let message = format!("{err:?}").into();
                log::error!("{message}");
                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                    header: "Error generating system prompt".into(),
                    message,
                }));
            }
            Ok(system_prompt) => {
                request.messages.push(LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text(system_prompt)],
                    cache: true,
                });
            }
        }
    } else {
        let message = "Context for system prompt unexpectedly not ready.".into();
        log::error!("{message}");
        cx.emit(ThreadEvent::ShowError(ThreadError::Message {
            header: "Error generating system prompt".into(),
            message,
        }));
    }

    // Index of the last request message eligible for prompt caching (i.e.
    // one with no still-pending tool uses); only that one is flagged below.
    let mut message_ix_to_cache = None;
    for message in &self.messages {
        let mut request_message = LanguageModelRequestMessage {
            role: message.role,
            content: Vec::new(),
            cache: false,
        };

        // Attached context precedes the message's own segments.
        message
            .loaded_context
            .add_to_request_message(&mut request_message);

        for segment in &message.segments {
            match segment {
                MessageSegment::Text(text) => {
                    // Empty text segments are dropped from the request.
                    if !text.is_empty() {
                        request_message
                            .content
                            .push(MessageContent::Text(text.into()));
                    }
                }
                MessageSegment::Thinking { text, signature } => {
                    if !text.is_empty() {
                        request_message.content.push(MessageContent::Thinking {
                            text: text.into(),
                            signature: signature.clone(),
                        });
                    }
                }
                MessageSegment::RedactedThinking(data) => {
                    request_message
                        .content
                        .push(MessageContent::RedactedThinking(data.clone()));
                }
            };
        }

        // Completed tool results are replayed as a synthetic User message
        // immediately after the assistant message that issued the tool uses.
        let mut cache_message = true;
        let mut tool_results_message = LanguageModelRequestMessage {
            role: Role::User,
            content: Vec::new(),
            cache: false,
        };
        for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
            if let Some(tool_result) = tool_result {
                request_message
                    .content
                    .push(MessageContent::ToolUse(tool_use.clone()));
                tool_results_message
                    .content
                    .push(MessageContent::ToolResult(LanguageModelToolResult {
                        tool_use_id: tool_use.id.clone(),
                        tool_name: tool_result.tool_name.clone(),
                        is_error: tool_result.is_error,
                        content: if tool_result.content.is_empty() {
                            // Surprisingly, the API fails if we return an empty string here.
                            // It thinks we are sending a tool use without a tool result.
                            "<Tool returned an empty string>".into()
                        } else {
                            tool_result.content.clone()
                        },
                        output: None,
                    }));
            } else {
                // A pending tool use makes this message ineligible for
                // caching, since its content is still going to change.
                cache_message = false;
                log::debug!(
                    "skipped tool use {:?} because it is still pending",
                    tool_use
                );
            }
        }

        if cache_message {
            message_ix_to_cache = Some(request.messages.len());
        }
        request.messages.push(request_message);

        if !tool_results_message.content.is_empty() {
            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(tool_results_message);
        }
    }

    // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
    if let Some(message_ix_to_cache) = message_ix_to_cache {
        request.messages[message_ix_to_cache].cache = true;
    }

    request.tools = available_tools;
    // Only request the thread's completion mode when the model supports
    // max mode; otherwise fall back to Normal.
    request.mode = if model.supports_max_mode() {
        Some(self.completion_mode.into())
    } else {
        Some(CompletionMode::Normal.into())
    };

    request
}
1394
1395 fn to_summarize_request(
1396 &self,
1397 model: &Arc<dyn LanguageModel>,
1398 intent: CompletionIntent,
1399 added_user_message: String,
1400 cx: &App,
1401 ) -> LanguageModelRequest {
1402 let mut request = LanguageModelRequest {
1403 thread_id: None,
1404 prompt_id: None,
1405 intent: Some(intent),
1406 mode: None,
1407 messages: vec![],
1408 tools: Vec::new(),
1409 tool_choice: None,
1410 stop: Vec::new(),
1411 temperature: AgentSettings::temperature_for_model(model, cx),
1412 };
1413
1414 for message in &self.messages {
1415 let mut request_message = LanguageModelRequestMessage {
1416 role: message.role,
1417 content: Vec::new(),
1418 cache: false,
1419 };
1420
1421 for segment in &message.segments {
1422 match segment {
1423 MessageSegment::Text(text) => request_message
1424 .content
1425 .push(MessageContent::Text(text.clone())),
1426 MessageSegment::Thinking { .. } => {}
1427 MessageSegment::RedactedThinking(_) => {}
1428 }
1429 }
1430
1431 if request_message.content.is_empty() {
1432 continue;
1433 }
1434
1435 request.messages.push(request_message);
1436 }
1437
1438 request.messages.push(LanguageModelRequestMessage {
1439 role: Role::User,
1440 content: vec![MessageContent::Text(added_user_message)],
1441 cache: false,
1442 });
1443
1444 request
1445 }
1446
1447 pub fn stream_completion(
1448 &mut self,
1449 request: LanguageModelRequest,
1450 model: Arc<dyn LanguageModel>,
1451 window: Option<AnyWindowHandle>,
1452 cx: &mut Context<Self>,
1453 ) {
1454 self.tool_use_limit_reached = false;
1455
1456 let pending_completion_id = post_inc(&mut self.completion_count);
1457 let mut request_callback_parameters = if self.request_callback.is_some() {
1458 Some((request.clone(), Vec::new()))
1459 } else {
1460 None
1461 };
1462 let prompt_id = self.last_prompt_id.clone();
1463 let tool_use_metadata = ToolUseMetadata {
1464 model: model.clone(),
1465 thread_id: self.id.clone(),
1466 prompt_id: prompt_id.clone(),
1467 };
1468
1469 self.last_received_chunk_at = Some(Instant::now());
1470
1471 let task = cx.spawn(async move |thread, cx| {
1472 let stream_completion_future = model.stream_completion(request, &cx);
1473 let initial_token_usage =
1474 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1475 let stream_completion = async {
1476 let mut events = stream_completion_future.await?;
1477
1478 let mut stop_reason = StopReason::EndTurn;
1479 let mut current_token_usage = TokenUsage::default();
1480
1481 thread
1482 .update(cx, |_thread, cx| {
1483 cx.emit(ThreadEvent::NewRequest);
1484 })
1485 .ok();
1486
1487 let mut request_assistant_message_id = None;
1488
1489 while let Some(event) = events.next().await {
1490 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1491 response_events
1492 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1493 }
1494
1495 thread.update(cx, |thread, cx| {
1496 let event = match event {
1497 Ok(event) => event,
1498 Err(error) => {
1499 match error {
1500 LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
1501 anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
1502 }
1503 LanguageModelCompletionError::Overloaded => {
1504 anyhow::bail!(LanguageModelKnownError::Overloaded);
1505 }
1506 LanguageModelCompletionError::ApiInternalServerError =>{
1507 anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
1508 }
1509 LanguageModelCompletionError::PromptTooLarge { tokens } => {
1510 let tokens = tokens.unwrap_or_else(|| {
1511 // We didn't get an exact token count from the API, so fall back on our estimate.
1512 thread.total_token_usage()
1513 .map(|usage| usage.total)
1514 .unwrap_or(0)
1515 // We know the context window was exceeded in practice, so if our estimate was
1516 // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
1517 .max(model.max_token_count().saturating_add(1))
1518 });
1519
1520 anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
1521 }
1522 LanguageModelCompletionError::ApiReadResponseError(io_error) => {
1523 anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
1524 }
1525 LanguageModelCompletionError::UnknownResponseFormat(error) => {
1526 anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
1527 }
1528 LanguageModelCompletionError::HttpResponseError { status, ref body } => {
1529 if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
1530 anyhow::bail!(known_error);
1531 } else {
1532 return Err(error.into());
1533 }
1534 }
1535 LanguageModelCompletionError::DeserializeResponse(error) => {
1536 anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
1537 }
1538 LanguageModelCompletionError::BadInputJson {
1539 id,
1540 tool_name,
1541 raw_input: invalid_input_json,
1542 json_parse_error,
1543 } => {
1544 thread.receive_invalid_tool_json(
1545 id,
1546 tool_name,
1547 invalid_input_json,
1548 json_parse_error,
1549 window,
1550 cx,
1551 );
1552 return Ok(());
1553 }
1554 // These are all errors we can't automatically attempt to recover from (e.g. by retrying)
1555 err @ LanguageModelCompletionError::BadRequestFormat |
1556 err @ LanguageModelCompletionError::AuthenticationError |
1557 err @ LanguageModelCompletionError::PermissionError |
1558 err @ LanguageModelCompletionError::ApiEndpointNotFound |
1559 err @ LanguageModelCompletionError::SerializeRequest(_) |
1560 err @ LanguageModelCompletionError::BuildRequestBody(_) |
1561 err @ LanguageModelCompletionError::HttpSend(_) => {
1562 anyhow::bail!(err);
1563 }
1564 LanguageModelCompletionError::Other(error) => {
1565 return Err(error);
1566 }
1567 }
1568 }
1569 };
1570
1571 match event {
1572 LanguageModelCompletionEvent::StartMessage { .. } => {
1573 request_assistant_message_id =
1574 Some(thread.insert_assistant_message(
1575 vec![MessageSegment::Text(String::new())],
1576 cx,
1577 ));
1578 }
1579 LanguageModelCompletionEvent::Stop(reason) => {
1580 stop_reason = reason;
1581 }
1582 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1583 thread.update_token_usage_at_last_message(token_usage);
1584 thread.cumulative_token_usage = thread.cumulative_token_usage
1585 + token_usage
1586 - current_token_usage;
1587 current_token_usage = token_usage;
1588 }
1589 LanguageModelCompletionEvent::Text(chunk) => {
1590 thread.received_chunk();
1591
1592 cx.emit(ThreadEvent::ReceivedTextChunk);
1593 if let Some(last_message) = thread.messages.last_mut() {
1594 if last_message.role == Role::Assistant
1595 && !thread.tool_use.has_tool_results(last_message.id)
1596 {
1597 last_message.push_text(&chunk);
1598 cx.emit(ThreadEvent::StreamedAssistantText(
1599 last_message.id,
1600 chunk,
1601 ));
1602 } else {
1603 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1604 // of a new Assistant response.
1605 //
1606 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1607 // will result in duplicating the text of the chunk in the rendered Markdown.
1608 request_assistant_message_id =
1609 Some(thread.insert_assistant_message(
1610 vec![MessageSegment::Text(chunk.to_string())],
1611 cx,
1612 ));
1613 };
1614 }
1615 }
1616 LanguageModelCompletionEvent::Thinking {
1617 text: chunk,
1618 signature,
1619 } => {
1620 thread.received_chunk();
1621
1622 if let Some(last_message) = thread.messages.last_mut() {
1623 if last_message.role == Role::Assistant
1624 && !thread.tool_use.has_tool_results(last_message.id)
1625 {
1626 last_message.push_thinking(&chunk, signature);
1627 cx.emit(ThreadEvent::StreamedAssistantThinking(
1628 last_message.id,
1629 chunk,
1630 ));
1631 } else {
1632 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1633 // of a new Assistant response.
1634 //
1635 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1636 // will result in duplicating the text of the chunk in the rendered Markdown.
1637 request_assistant_message_id =
1638 Some(thread.insert_assistant_message(
1639 vec![MessageSegment::Thinking {
1640 text: chunk.to_string(),
1641 signature,
1642 }],
1643 cx,
1644 ));
1645 };
1646 }
1647 }
1648 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1649 let last_assistant_message_id = request_assistant_message_id
1650 .unwrap_or_else(|| {
1651 let new_assistant_message_id =
1652 thread.insert_assistant_message(vec![], cx);
1653 request_assistant_message_id =
1654 Some(new_assistant_message_id);
1655 new_assistant_message_id
1656 });
1657
1658 let tool_use_id = tool_use.id.clone();
1659 let streamed_input = if tool_use.is_input_complete {
1660 None
1661 } else {
1662 Some((&tool_use.input).clone())
1663 };
1664
1665 let ui_text = thread.tool_use.request_tool_use(
1666 last_assistant_message_id,
1667 tool_use,
1668 tool_use_metadata.clone(),
1669 cx,
1670 );
1671
1672 if let Some(input) = streamed_input {
1673 cx.emit(ThreadEvent::StreamedToolUse {
1674 tool_use_id,
1675 ui_text,
1676 input,
1677 });
1678 }
1679 }
1680 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1681 if let Some(completion) = thread
1682 .pending_completions
1683 .iter_mut()
1684 .find(|completion| completion.id == pending_completion_id)
1685 {
1686 match status_update {
1687 CompletionRequestStatus::Queued {
1688 position,
1689 } => {
1690 completion.queue_state = QueueState::Queued { position };
1691 }
1692 CompletionRequestStatus::Started => {
1693 completion.queue_state = QueueState::Started;
1694 }
1695 CompletionRequestStatus::Failed {
1696 code, message, request_id
1697 } => {
1698 anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
1699 }
1700 CompletionRequestStatus::UsageUpdated {
1701 amount, limit
1702 } => {
1703 thread.update_model_request_usage(amount as u32, limit, cx);
1704 }
1705 CompletionRequestStatus::ToolUseLimitReached => {
1706 thread.tool_use_limit_reached = true;
1707 cx.emit(ThreadEvent::ToolUseLimitReached);
1708 }
1709 }
1710 }
1711 }
1712 }
1713
1714 thread.touch_updated_at();
1715 cx.emit(ThreadEvent::StreamedCompletion);
1716 cx.notify();
1717
1718 thread.auto_capture_telemetry(cx);
1719 Ok(())
1720 })??;
1721
1722 smol::future::yield_now().await;
1723 }
1724
1725 thread.update(cx, |thread, cx| {
1726 thread.last_received_chunk_at = None;
1727 thread
1728 .pending_completions
1729 .retain(|completion| completion.id != pending_completion_id);
1730
1731 // If there is a response without tool use, summarize the message. Otherwise,
1732 // allow two tool uses before summarizing.
1733 if matches!(thread.summary, ThreadSummary::Pending)
1734 && thread.messages.len() >= 2
1735 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1736 {
1737 thread.summarize(cx);
1738 }
1739 })?;
1740
1741 anyhow::Ok(stop_reason)
1742 };
1743
1744 let result = stream_completion.await;
1745
1746 thread
1747 .update(cx, |thread, cx| {
1748 thread.finalize_pending_checkpoint(cx);
1749 match result.as_ref() {
1750 Ok(stop_reason) => match stop_reason {
1751 StopReason::ToolUse => {
1752 let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1753 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1754 }
1755 StopReason::EndTurn | StopReason::MaxTokens => {
1756 thread.project.update(cx, |project, cx| {
1757 project.set_agent_location(None, cx);
1758 });
1759 }
1760 StopReason::Refusal => {
1761 thread.project.update(cx, |project, cx| {
1762 project.set_agent_location(None, cx);
1763 });
1764
1765 // Remove the turn that was refused.
1766 //
1767 // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1768 {
1769 let mut messages_to_remove = Vec::new();
1770
1771 for (ix, message) in thread.messages.iter().enumerate().rev() {
1772 messages_to_remove.push(message.id);
1773
1774 if message.role == Role::User {
1775 if ix == 0 {
1776 break;
1777 }
1778
1779 if let Some(prev_message) = thread.messages.get(ix - 1) {
1780 if prev_message.role == Role::Assistant {
1781 break;
1782 }
1783 }
1784 }
1785 }
1786
1787 for message_id in messages_to_remove {
1788 thread.delete_message(message_id, cx);
1789 }
1790 }
1791
1792 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1793 header: "Language model refusal".into(),
1794 message: "Model refused to generate content for safety reasons.".into(),
1795 }));
1796 }
1797 },
1798 Err(error) => {
1799 thread.project.update(cx, |project, cx| {
1800 project.set_agent_location(None, cx);
1801 });
1802
1803 fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
1804 let error_message = error
1805 .chain()
1806 .map(|err| err.to_string())
1807 .collect::<Vec<_>>()
1808 .join("\n");
1809 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1810 header: "Error interacting with language model".into(),
1811 message: SharedString::from(error_message.clone()),
1812 }));
1813 }
1814
1815 if error.is::<PaymentRequiredError>() {
1816 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1817 } else if let Some(error) =
1818 error.downcast_ref::<ModelRequestLimitReachedError>()
1819 {
1820 cx.emit(ThreadEvent::ShowError(
1821 ThreadError::ModelRequestLimitReached { plan: error.plan },
1822 ));
1823 } else if let Some(known_error) =
1824 error.downcast_ref::<LanguageModelKnownError>()
1825 {
1826 match known_error {
1827 LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
1828 thread.exceeded_window_error = Some(ExceededWindowError {
1829 model_id: model.id(),
1830 token_count: *tokens,
1831 });
1832 cx.notify();
1833 }
1834 LanguageModelKnownError::RateLimitExceeded { .. } => {
1835 // In the future we will report the error to the user, wait retry_after, and then retry.
1836 emit_generic_error(error, cx);
1837 }
1838 LanguageModelKnownError::Overloaded => {
1839 // In the future we will wait and then retry, up to N times.
1840 emit_generic_error(error, cx);
1841 }
1842 LanguageModelKnownError::ApiInternalServerError => {
1843 // In the future we will retry the request, but only once.
1844 emit_generic_error(error, cx);
1845 }
1846 LanguageModelKnownError::ReadResponseError(_) |
1847 LanguageModelKnownError::DeserializeResponse(_) |
1848 LanguageModelKnownError::UnknownResponseFormat(_) => {
1849 // In the future we will attempt to re-roll response, but only once
1850 emit_generic_error(error, cx);
1851 }
1852 }
1853 } else {
1854 emit_generic_error(error, cx);
1855 }
1856
1857 thread.cancel_last_completion(window, cx);
1858 }
1859 }
1860
1861 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1862
1863 if let Some((request_callback, (request, response_events))) = thread
1864 .request_callback
1865 .as_mut()
1866 .zip(request_callback_parameters.as_ref())
1867 {
1868 request_callback(request, response_events);
1869 }
1870
1871 thread.auto_capture_telemetry(cx);
1872
1873 if let Ok(initial_usage) = initial_token_usage {
1874 let usage = thread.cumulative_token_usage - initial_usage;
1875
1876 telemetry::event!(
1877 "Assistant Thread Completion",
1878 thread_id = thread.id().to_string(),
1879 prompt_id = prompt_id,
1880 model = model.telemetry_id(),
1881 model_provider = model.provider_id().to_string(),
1882 input_tokens = usage.input_tokens,
1883 output_tokens = usage.output_tokens,
1884 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1885 cache_read_input_tokens = usage.cache_read_input_tokens,
1886 );
1887 }
1888 })
1889 .ok();
1890 });
1891
1892 self.pending_completions.push(PendingCompletion {
1893 id: pending_completion_id,
1894 queue_state: QueueState::Sending,
1895 _task: task,
1896 });
1897 }
1898
1899 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1900 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1901 println!("No thread summary model");
1902 return;
1903 };
1904
1905 if !model.provider.is_authenticated(cx) {
1906 return;
1907 }
1908
1909 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
1910
1911 let request = self.to_summarize_request(
1912 &model.model,
1913 CompletionIntent::ThreadSummarization,
1914 added_user_message.into(),
1915 cx,
1916 );
1917
1918 self.summary = ThreadSummary::Generating;
1919
1920 self.pending_summary = cx.spawn(async move |this, cx| {
1921 let result = async {
1922 let mut messages = model.model.stream_completion(request, &cx).await?;
1923
1924 let mut new_summary = String::new();
1925 while let Some(event) = messages.next().await {
1926 let Ok(event) = event else {
1927 continue;
1928 };
1929 let text = match event {
1930 LanguageModelCompletionEvent::Text(text) => text,
1931 LanguageModelCompletionEvent::StatusUpdate(
1932 CompletionRequestStatus::UsageUpdated { amount, limit },
1933 ) => {
1934 this.update(cx, |thread, cx| {
1935 thread.update_model_request_usage(amount as u32, limit, cx);
1936 })?;
1937 continue;
1938 }
1939 _ => continue,
1940 };
1941
1942 let mut lines = text.lines();
1943 new_summary.extend(lines.next());
1944
1945 // Stop if the LLM generated multiple lines.
1946 if lines.next().is_some() {
1947 break;
1948 }
1949 }
1950
1951 anyhow::Ok(new_summary)
1952 }
1953 .await;
1954
1955 this.update(cx, |this, cx| {
1956 match result {
1957 Ok(new_summary) => {
1958 if new_summary.is_empty() {
1959 this.summary = ThreadSummary::Error;
1960 } else {
1961 this.summary = ThreadSummary::Ready(new_summary.into());
1962 }
1963 }
1964 Err(err) => {
1965 this.summary = ThreadSummary::Error;
1966 log::error!("Failed to generate thread summary: {}", err);
1967 }
1968 }
1969 cx.emit(ThreadEvent::SummaryGenerated);
1970 })
1971 .log_err()?;
1972
1973 Some(())
1974 });
1975 }
1976
/// Starts generating a detailed (multi-paragraph) summary of the thread,
/// unless one is already generated or generating for the current last
/// message.
///
/// Progress is published through `detailed_summary_tx`/`detailed_summary_rx`;
/// on success the thread is also saved via `thread_store` so the summary can
/// be reused later.
pub fn start_generating_detailed_summary_if_needed(
    &mut self,
    thread_store: WeakEntity<ThreadStore>,
    cx: &mut Context<Self>,
) {
    // Nothing to summarize in an empty thread.
    let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
        return;
    };

    match &*self.detailed_summary_rx.borrow() {
        DetailedSummaryState::Generating { message_id, .. }
        | DetailedSummaryState::Generated { message_id, .. }
            if *message_id == last_message_id =>
        {
            // Already up-to-date
            return;
        }
        _ => {}
    }

    let Some(ConfiguredModel { model, provider }) =
        LanguageModelRegistry::read_global(cx).thread_summary_model()
    else {
        return;
    };

    if !provider.is_authenticated(cx) {
        return;
    }

    let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

    let request = self.to_summarize_request(
        &model,
        CompletionIntent::ThreadContextSummarization,
        added_user_message.into(),
        cx,
    );

    // Mark the summary as generating for the current last message before
    // spawning, so concurrent callers see the in-progress state.
    *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
        message_id: last_message_id,
    };

    // Replace the detailed summarization task if there is one, cancelling it. It would probably
    // be better to allow the old task to complete, but this would require logic for choosing
    // which result to prefer (the old task could complete after the new one, resulting in a
    // stale summary).
    self.detailed_summary_task = cx.spawn(async move |thread, cx| {
        let stream = model.stream_completion_text(request, &cx);
        let Some(mut messages) = stream.await.log_err() else {
            // The request failed to start; roll the state back so a later
            // call can retry.
            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() =
                        DetailedSummaryState::NotGenerated;
                })
                .ok()?;
            return None;
        };

        let mut new_detailed_summary = String::new();

        // Accumulate all text chunks; individual chunk errors are logged
        // and skipped.
        while let Some(chunk) = messages.stream.next().await {
            if let Some(chunk) = chunk.log_err() {
                new_detailed_summary.push_str(&chunk);
            }
        }

        thread
            .update(cx, |thread, _cx| {
                *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                    text: new_detailed_summary.into(),
                    message_id: last_message_id,
                };
            })
            .ok()?;

        // Save thread so its summary can be reused later
        if let Some(thread) = thread.upgrade() {
            if let Ok(Ok(save_task)) = cx.update(|cx| {
                thread_store
                    .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
            }) {
                save_task.await.log_err();
            }
        }

        Some(())
    });
}
2066
/// Waits until the detailed summary settles, returning it, or falls back to
/// the thread's plain text.
///
/// Loops while the state is `Generating`; returns the generated text once
/// available, the plain [`Thread::text`] rendering when no summary was
/// generated, or `None` when the thread entity or channel is gone.
pub async fn wait_for_detailed_summary_or_text(
    this: &Entity<Self>,
    cx: &mut AsyncApp,
) -> Option<SharedString> {
    // Clone the receiver so we can await state changes without holding a
    // read lock on the thread.
    let mut detailed_summary_rx = this
        .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
        .ok()?;
    loop {
        match detailed_summary_rx.recv().await? {
            DetailedSummaryState::Generating { .. } => {}
            DetailedSummaryState::NotGenerated => {
                return this.read_with(cx, |this, _cx| this.text().into()).ok();
            }
            DetailedSummaryState::Generated { text, .. } => return Some(text),
        }
    }
}
2084
2085 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2086 self.detailed_summary_rx
2087 .borrow()
2088 .text()
2089 .unwrap_or_else(|| self.text().into())
2090 }
2091
2092 pub fn is_generating_detailed_summary(&self) -> bool {
2093 matches!(
2094 &*self.detailed_summary_rx.borrow(),
2095 DetailedSummaryState::Generating { .. }
2096 )
2097 }
2098
/// Dispatches every idle pending tool use: tools that need confirmation
/// (and auto-approval is off) emit [`ThreadEvent::ToolConfirmationNeeded`],
/// the rest run immediately; unknown tool names are handled as
/// hallucinations. Returns the pending tool uses that were dispatched.
pub fn use_pending_tools(
    &mut self,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Self>,
    model: Arc<dyn LanguageModel>,
) -> Vec<PendingToolUse> {
    self.auto_capture_telemetry(cx);
    // Build one request snapshot shared by all tools in this batch.
    let request =
        Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
    // Clone the idle tool uses so we can mutate `self` while iterating.
    let pending_tool_uses = self
        .tool_use
        .pending_tool_uses()
        .into_iter()
        .filter(|tool_use| tool_use.status.is_idle())
        .cloned()
        .collect::<Vec<_>>();

    for tool_use in pending_tool_uses.iter() {
        if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
            if tool.needs_confirmation(&tool_use.input, cx)
                && !AgentSettings::get_global(cx).always_allow_tool_actions
            {
                self.tool_use.confirm_tool_use(
                    tool_use.id.clone(),
                    tool_use.ui_text.clone(),
                    tool_use.input.clone(),
                    request.clone(),
                    tool,
                );
                cx.emit(ThreadEvent::ToolConfirmationNeeded);
            } else {
                self.run_tool(
                    tool_use.id.clone(),
                    tool_use.ui_text.clone(),
                    tool_use.input.clone(),
                    request.clone(),
                    tool,
                    model.clone(),
                    window,
                    cx,
                );
            }
        } else {
            // The model asked for a tool that doesn't exist (or isn't
            // enabled); report it back as an error result.
            self.handle_hallucinated_tool_use(
                tool_use.id.clone(),
                tool_use.name.clone(),
                window,
                cx,
            );
        }
    }

    pending_tool_uses
}
2153
2154 pub fn handle_hallucinated_tool_use(
2155 &mut self,
2156 tool_use_id: LanguageModelToolUseId,
2157 hallucinated_tool_name: Arc<str>,
2158 window: Option<AnyWindowHandle>,
2159 cx: &mut Context<Thread>,
2160 ) {
2161 let available_tools = self.profile.enabled_tools(cx);
2162
2163 let tool_list = available_tools
2164 .iter()
2165 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2166 .collect::<Vec<_>>()
2167 .join("\n");
2168
2169 let error_message = format!(
2170 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2171 hallucinated_tool_name, tool_list
2172 );
2173
2174 let pending_tool_use = self.tool_use.insert_tool_output(
2175 tool_use_id.clone(),
2176 hallucinated_tool_name,
2177 Err(anyhow!("Missing tool call: {error_message}")),
2178 self.configured_model.as_ref(),
2179 );
2180
2181 cx.emit(ThreadEvent::MissingToolUse {
2182 tool_use_id: tool_use_id.clone(),
2183 ui_text: error_message.into(),
2184 });
2185
2186 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2187 }
2188
/// Records an error result for a tool use whose input JSON could not be
/// parsed, then finishes the tool use so the turn can continue.
///
/// Emits [`ThreadEvent::InvalidToolInput`] carrying the raw invalid JSON
/// for display.
pub fn receive_invalid_tool_json(
    &mut self,
    tool_use_id: LanguageModelToolUseId,
    tool_name: Arc<str>,
    invalid_json: Arc<str>,
    error: String,
    window: Option<AnyWindowHandle>,
    cx: &mut Context<Thread>,
) {
    log::error!("The model returned invalid input JSON: {invalid_json}");

    // Store the parse failure as the tool's (error) output.
    let pending_tool_use = self.tool_use.insert_tool_output(
        tool_use_id.clone(),
        tool_name,
        Err(anyhow!("Error parsing input JSON: {error}")),
        self.configured_model.as_ref(),
    );
    let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
        pending_tool_use.ui_text.clone()
    } else {
        // Shouldn't happen: a result arrived for a tool use we never tracked.
        log::error!(
            "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
        );
        format!("Unknown tool {}", tool_use_id).into()
    };

    cx.emit(ThreadEvent::InvalidToolInput {
        tool_use_id: tool_use_id.clone(),
        ui_text,
        invalid_input_json: invalid_json,
    });

    self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
}
2223
2224 pub fn run_tool(
2225 &mut self,
2226 tool_use_id: LanguageModelToolUseId,
2227 ui_text: impl Into<SharedString>,
2228 input: serde_json::Value,
2229 request: Arc<LanguageModelRequest>,
2230 tool: Arc<dyn Tool>,
2231 model: Arc<dyn LanguageModel>,
2232 window: Option<AnyWindowHandle>,
2233 cx: &mut Context<Thread>,
2234 ) {
2235 let task =
2236 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2237 self.tool_use
2238 .run_pending_tool(tool_use_id, ui_text.into(), task);
2239 }
2240
    /// Starts a tool's execution and returns the task that resolves once the
    /// tool's output has been recorded on the thread.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // `Tool::run` kicks off the work; `tool_result.output` is the future
        // for its eventual result.
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool was running;
                // `.ok()` silently discards the update in that case.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2287
2288 fn tool_finished(
2289 &mut self,
2290 tool_use_id: LanguageModelToolUseId,
2291 pending_tool_use: Option<PendingToolUse>,
2292 canceled: bool,
2293 window: Option<AnyWindowHandle>,
2294 cx: &mut Context<Self>,
2295 ) {
2296 if self.all_tools_finished() {
2297 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2298 if !canceled {
2299 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2300 }
2301 self.auto_capture_telemetry(cx);
2302 }
2303 }
2304
2305 cx.emit(ThreadEvent::ToolFinished {
2306 tool_use_id,
2307 pending_tool_use,
2308 });
2309 }
2310
2311 /// Cancels the last pending completion, if there are any pending.
2312 ///
2313 /// Returns whether a completion was canceled.
2314 pub fn cancel_last_completion(
2315 &mut self,
2316 window: Option<AnyWindowHandle>,
2317 cx: &mut Context<Self>,
2318 ) -> bool {
2319 let mut canceled = self.pending_completions.pop().is_some();
2320
2321 for pending_tool_use in self.tool_use.cancel_pending() {
2322 canceled = true;
2323 self.tool_finished(
2324 pending_tool_use.id.clone(),
2325 Some(pending_tool_use),
2326 true,
2327 window,
2328 cx,
2329 );
2330 }
2331
2332 if canceled {
2333 cx.emit(ThreadEvent::CompletionCanceled);
2334
2335 // When canceled, we always want to insert the checkpoint.
2336 // (We skip over finalize_pending_checkpoint, because it
2337 // would conclude we didn't have anything to insert here.)
2338 if let Some(checkpoint) = self.pending_checkpoint.take() {
2339 self.insert_checkpoint(checkpoint, cx);
2340 }
2341 } else {
2342 self.finalize_pending_checkpoint(cx);
2343 }
2344
2345 canceled
2346 }
2347
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits the event;
    /// the listeners themselves perform the actual cancellation.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2355
    /// Returns the thread-level feedback, if the user has recorded any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2359
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2363
2364 pub fn report_message_feedback(
2365 &mut self,
2366 message_id: MessageId,
2367 feedback: ThreadFeedback,
2368 cx: &mut Context<Self>,
2369 ) -> Task<Result<()>> {
2370 if self.message_feedback.get(&message_id) == Some(&feedback) {
2371 return Task::ready(Ok(()));
2372 }
2373
2374 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2375 let serialized_thread = self.serialize(cx);
2376 let thread_id = self.id().clone();
2377 let client = self.project.read(cx).client();
2378
2379 let enabled_tool_names: Vec<String> = self
2380 .profile
2381 .enabled_tools(cx)
2382 .iter()
2383 .map(|tool| tool.name())
2384 .collect();
2385
2386 self.message_feedback.insert(message_id, feedback);
2387
2388 cx.notify();
2389
2390 let message_content = self
2391 .message(message_id)
2392 .map(|msg| msg.to_string())
2393 .unwrap_or_default();
2394
2395 cx.background_spawn(async move {
2396 let final_project_snapshot = final_project_snapshot.await;
2397 let serialized_thread = serialized_thread.await?;
2398 let thread_data =
2399 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2400
2401 let rating = match feedback {
2402 ThreadFeedback::Positive => "positive",
2403 ThreadFeedback::Negative => "negative",
2404 };
2405 telemetry::event!(
2406 "Assistant Thread Rated",
2407 rating,
2408 thread_id,
2409 enabled_tool_names,
2410 message_id = message_id.0,
2411 message_content,
2412 thread_data,
2413 final_project_snapshot
2414 );
2415 client.telemetry().flush_events().await;
2416
2417 Ok(())
2418 })
2419 }
2420
2421 pub fn report_feedback(
2422 &mut self,
2423 feedback: ThreadFeedback,
2424 cx: &mut Context<Self>,
2425 ) -> Task<Result<()>> {
2426 let last_assistant_message_id = self
2427 .messages
2428 .iter()
2429 .rev()
2430 .find(|msg| msg.role == Role::Assistant)
2431 .map(|msg| msg.id);
2432
2433 if let Some(message_id) = last_assistant_message_id {
2434 self.report_message_feedback(message_id, feedback, cx)
2435 } else {
2436 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2437 let serialized_thread = self.serialize(cx);
2438 let thread_id = self.id().clone();
2439 let client = self.project.read(cx).client();
2440 self.feedback = Some(feedback);
2441 cx.notify();
2442
2443 cx.background_spawn(async move {
2444 let final_project_snapshot = final_project_snapshot.await;
2445 let serialized_thread = serialized_thread.await?;
2446 let thread_data = serde_json::to_value(serialized_thread)
2447 .unwrap_or_else(|_| serde_json::Value::Null);
2448
2449 let rating = match feedback {
2450 ThreadFeedback::Positive => "positive",
2451 ThreadFeedback::Negative => "negative",
2452 };
2453 telemetry::event!(
2454 "Assistant Thread Rated",
2455 rating,
2456 thread_id,
2457 thread_data,
2458 final_project_snapshot
2459 );
2460 client.telemetry().flush_events().await;
2461
2462 Ok(())
2463 })
2464 }
2465 }
2466
2467 /// Create a snapshot of the current project state including git information and unsaved buffers.
2468 fn project_snapshot(
2469 project: Entity<Project>,
2470 cx: &mut Context<Self>,
2471 ) -> Task<Arc<ProjectSnapshot>> {
2472 let git_store = project.read(cx).git_store().clone();
2473 let worktree_snapshots: Vec<_> = project
2474 .read(cx)
2475 .visible_worktrees(cx)
2476 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2477 .collect();
2478
2479 cx.spawn(async move |_, cx| {
2480 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2481
2482 let mut unsaved_buffers = Vec::new();
2483 cx.update(|app_cx| {
2484 let buffer_store = project.read(app_cx).buffer_store();
2485 for buffer_handle in buffer_store.read(app_cx).buffers() {
2486 let buffer = buffer_handle.read(app_cx);
2487 if buffer.is_dirty() {
2488 if let Some(file) = buffer.file() {
2489 let path = file.path().to_string_lossy().to_string();
2490 unsaved_buffers.push(path);
2491 }
2492 }
2493 }
2494 })
2495 .ok();
2496
2497 Arc::new(ProjectSnapshot {
2498 worktree_snapshots,
2499 unsaved_buffer_paths: unsaved_buffers,
2500 timestamp: Utc::now(),
2501 })
2502 })
2503 }
2504
    /// Captures a [`WorktreeSnapshot`] (absolute path plus optional git
    /// metadata) for one worktree. Git details are resolved asynchronously on
    /// the repository's job queue; any failure degrades to `git_state: None`.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app was torn down before we could read the worktree,
            // return an empty snapshot rather than failing.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose root contains this worktree, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        // Remote URL, HEAD sha, and diff can only be read from
                        // a local repository backend.
                        repo.send_job(None, |state, _| async move {
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<Task>>; every failure path
            // collapses to `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2582
2583 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2584 let mut markdown = Vec::new();
2585
2586 let summary = self.summary().or_default();
2587 writeln!(markdown, "# {summary}\n")?;
2588
2589 for message in self.messages() {
2590 writeln!(
2591 markdown,
2592 "## {role}\n",
2593 role = match message.role {
2594 Role::User => "User",
2595 Role::Assistant => "Agent",
2596 Role::System => "System",
2597 }
2598 )?;
2599
2600 if !message.loaded_context.text.is_empty() {
2601 writeln!(markdown, "{}", message.loaded_context.text)?;
2602 }
2603
2604 if !message.loaded_context.images.is_empty() {
2605 writeln!(
2606 markdown,
2607 "\n{} images attached as context.\n",
2608 message.loaded_context.images.len()
2609 )?;
2610 }
2611
2612 for segment in &message.segments {
2613 match segment {
2614 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2615 MessageSegment::Thinking { text, .. } => {
2616 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2617 }
2618 MessageSegment::RedactedThinking(_) => {}
2619 }
2620 }
2621
2622 for tool_use in self.tool_uses_for_message(message.id, cx) {
2623 writeln!(
2624 markdown,
2625 "**Use Tool: {} ({})**",
2626 tool_use.name, tool_use.id
2627 )?;
2628 writeln!(markdown, "```json")?;
2629 writeln!(
2630 markdown,
2631 "{}",
2632 serde_json::to_string_pretty(&tool_use.input)?
2633 )?;
2634 writeln!(markdown, "```")?;
2635 }
2636
2637 for tool_result in self.tool_results_for_message(message.id) {
2638 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2639 if tool_result.is_error {
2640 write!(markdown, " (Error)")?;
2641 }
2642
2643 writeln!(markdown, "**\n")?;
2644 match &tool_result.content {
2645 LanguageModelToolResultContent::Text(text) => {
2646 writeln!(markdown, "{text}")?;
2647 }
2648 LanguageModelToolResultContent::Image(image) => {
2649 writeln!(markdown, "", image.source)?;
2650 }
2651 }
2652
2653 if let Some(output) = tool_result.output.as_ref() {
2654 writeln!(
2655 markdown,
2656 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2657 serde_json::to_string_pretty(output)?
2658 )?;
2659 }
2660 }
2661 }
2662
2663 Ok(String::from_utf8_lossy(&markdown).to_string())
2664 }
2665
    /// Marks the agent's edits within `buffer_range` of `buffer` as kept
    /// (accepted) in the action log.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2676
    /// Marks every outstanding agent edit as kept (accepted) in the action log.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2681
    /// Rejects (reverts) the agent's edits within the given ranges of
    /// `buffer`, returning a task that resolves once the revert completes.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2692
    /// The log of agent-driven buffer edits for this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2696
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2700
2701 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2702 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2703 return;
2704 }
2705
2706 let now = Instant::now();
2707 if let Some(last) = self.last_auto_capture_at {
2708 if now.duration_since(last).as_secs() < 10 {
2709 return;
2710 }
2711 }
2712
2713 self.last_auto_capture_at = Some(now);
2714
2715 let thread_id = self.id().clone();
2716 let github_login = self
2717 .project
2718 .read(cx)
2719 .user_store()
2720 .read(cx)
2721 .current_user()
2722 .map(|user| user.github_login.clone());
2723 let client = self.project.read(cx).client();
2724 let serialize_task = self.serialize(cx);
2725
2726 cx.background_executor()
2727 .spawn(async move {
2728 if let Ok(serialized_thread) = serialize_task.await {
2729 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2730 telemetry::event!(
2731 "Agent Thread Auto-Captured",
2732 thread_id = thread_id.to_string(),
2733 thread_data = thread_data,
2734 auto_capture_reason = "tracked_user",
2735 github_login = github_login
2736 );
2737
2738 client.telemetry().flush_events().await;
2739 }
2740 }
2741 })
2742 .detach();
2743 }
2744
    /// Token usage accumulated across all completions in this thread so far.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2748
2749 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2750 let Some(model) = self.configured_model.as_ref() else {
2751 return TotalTokenUsage::default();
2752 };
2753
2754 let max = model.model.max_token_count();
2755
2756 let index = self
2757 .messages
2758 .iter()
2759 .position(|msg| msg.id == message_id)
2760 .unwrap_or(0);
2761
2762 if index == 0 {
2763 return TotalTokenUsage { total: 0, max };
2764 }
2765
2766 let token_usage = &self
2767 .request_token_usage
2768 .get(index - 1)
2769 .cloned()
2770 .unwrap_or_default();
2771
2772 TotalTokenUsage {
2773 total: token_usage.total_tokens(),
2774 max,
2775 }
2776 }
2777
2778 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2779 let model = self.configured_model.as_ref()?;
2780
2781 let max = model.model.max_token_count();
2782
2783 if let Some(exceeded_error) = &self.exceeded_window_error {
2784 if model.model.id() == exceeded_error.model_id {
2785 return Some(TotalTokenUsage {
2786 total: exceeded_error.token_count,
2787 max,
2788 });
2789 }
2790 }
2791
2792 let total = self
2793 .token_usage_at_last_message()
2794 .unwrap_or_default()
2795 .total_tokens();
2796
2797 Some(TotalTokenUsage { total, max })
2798 }
2799
2800 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2801 self.request_token_usage
2802 .get(self.messages.len().saturating_sub(1))
2803 .or_else(|| self.request_token_usage.last())
2804 .cloned()
2805 }
2806
2807 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2808 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2809 self.request_token_usage
2810 .resize(self.messages.len(), placeholder);
2811
2812 if let Some(last) = self.request_token_usage.last_mut() {
2813 *last = token_usage;
2814 }
2815 }
2816
2817 fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
2818 self.project.update(cx, |project, cx| {
2819 project.user_store().update(cx, |user_store, cx| {
2820 user_store.update_model_request_usage(
2821 ModelRequestUsage(RequestUsage {
2822 amount: amount as i32,
2823 limit,
2824 }),
2825 cx,
2826 )
2827 })
2828 });
2829 }
2830
2831 pub fn deny_tool_use(
2832 &mut self,
2833 tool_use_id: LanguageModelToolUseId,
2834 tool_name: Arc<str>,
2835 window: Option<AnyWindowHandle>,
2836 cx: &mut Context<Self>,
2837 ) {
2838 let err = Err(anyhow::anyhow!(
2839 "Permission to run tool action denied by user"
2840 ));
2841
2842 self.tool_use.insert_tool_output(
2843 tool_use_id.clone(),
2844 tool_name,
2845 err,
2846 self.configured_model.as_ref(),
2847 );
2848 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2849 }
2850}
2851
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's plan requires payment before more requests are served.
    #[error("Payment required")]
    PaymentRequired,
    /// The plan's model-request quota has been exhausted.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a header and a detail message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2864
/// Events emitted by a [`Thread`] while completions stream in and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// A chunk of assistant response text arrived for the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant "thinking" text arrived for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed (part of) a tool invocation.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model referenced a tool that doesn't exist or isn't enabled.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stream ended, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    /// Tool uses are ready to be scheduled.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool needs explicit user confirmation before it may run.
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
}

impl EventEmitter<ThreadEvent> for Thread {}
2911
/// A completion request currently in flight; dropping `_task` cancels the
/// underlying streaming work.
struct PendingCompletion {
    /// Identifier distinguishing this completion from others.
    id: usize,
    queue_state: QueueState,
    /// Keeps the streaming task alive for the lifetime of the completion.
    _task: Task<()>,
}
2917
2918/// Resolves tool name conflicts by ensuring all tool names are unique.
2919///
2920/// When multiple tools have the same name, this function applies the following rules:
2921/// 1. Native tools always keep their original name
2922/// 2. Context server tools get prefixed with their server ID and an underscore
2923/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
2924/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
2925///
2926/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
2927fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
2928 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
2929 let mut tool_name = tool.name();
2930 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
2931 tool_name
2932 }
2933
2934 const MAX_TOOL_NAME_LENGTH: usize = 64;
2935
2936 let mut duplicated_tool_names = HashSet::default();
2937 let mut seen_tool_names = HashSet::default();
2938 for tool in tools {
2939 let tool_name = resolve_tool_name(tool);
2940 if seen_tool_names.contains(&tool_name) {
2941 debug_assert!(
2942 tool.source() != assistant_tool::ToolSource::Native,
2943 "There are two built-in tools with the same name: {}",
2944 tool_name
2945 );
2946 duplicated_tool_names.insert(tool_name);
2947 } else {
2948 seen_tool_names.insert(tool_name);
2949 }
2950 }
2951
2952 if duplicated_tool_names.is_empty() {
2953 return tools
2954 .into_iter()
2955 .map(|tool| (resolve_tool_name(tool), tool.clone()))
2956 .collect();
2957 }
2958
2959 tools
2960 .into_iter()
2961 .filter_map(|tool| {
2962 let mut tool_name = resolve_tool_name(tool);
2963 if !duplicated_tool_names.contains(&tool_name) {
2964 return Some((tool_name, tool.clone()));
2965 }
2966 match tool.source() {
2967 assistant_tool::ToolSource::Native => {
2968 // Built-in tools always keep their original name
2969 Some((tool_name, tool.clone()))
2970 }
2971 assistant_tool::ToolSource::ContextServer { id } => {
2972 // Context server tools are prefixed with the context server ID, and truncated if necessary
2973 tool_name.insert(0, '_');
2974 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
2975 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
2976 let mut id = id.to_string();
2977 id.truncate(len);
2978 tool_name.insert_str(0, &id);
2979 } else {
2980 tool_name.insert_str(0, &id);
2981 }
2982
2983 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
2984
2985 if seen_tool_names.contains(&tool_name) {
2986 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
2987 None
2988 } else {
2989 Some((tool_name, tool.clone()))
2990 }
2991 }
2992 }
2993 })
2994 .collect()
2995}
2996
2997#[cfg(test)]
2998mod tests {
2999 use super::*;
3000 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
3001 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
3002 use assistant_tool::ToolRegistry;
3003 use editor::EditorSettings;
3004 use gpui::TestAppContext;
3005 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
3006 use project::{FakeFs, Project};
3007 use prompt_store::PromptBuilder;
3008 use serde_json::json;
3009 use settings::{Settings, SettingsStore};
3010 use std::sync::Arc;
3011 use theme::ThemeSettings;
3012 use ui::IconName;
3013 use util::path;
3014 use workspace::Workspace;
3015
    /// Verifies that file context attached to a user message appears both in
    /// the stored message and in the completion request sent to the model.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Request holds the system prompt plus our single user message.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3092
    /// Verifies `new_context_for_thread` only yields context not already
    /// attached to an earlier message, and honors the exclusion message id.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Excluding everything from message2 onward should re-offer contexts
        // 2 and 3 plus the newly added file4.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3248
    /// Verifies that messages without attached context have empty context
    /// text and are passed through to the request unchanged.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // System prompt plus the single user message.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3326
3327 #[gpui::test]
3328 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3329 init_test_settings(cx);
3330
3331 let project = create_test_project(
3332 cx,
3333 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3334 )
3335 .await;
3336
3337 let (_workspace, thread_store, thread, _context_store, _model) =
3338 setup_test_environment(cx, project.clone()).await;
3339
3340 // Check that we are starting with the default profile
3341 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3342 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3343 assert_eq!(
3344 profile,
3345 AgentProfile::new(AgentProfileId::default(), tool_set)
3346 );
3347 }
3348
3349 #[gpui::test]
3350 async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
3351 init_test_settings(cx);
3352
3353 let project = create_test_project(
3354 cx,
3355 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3356 )
3357 .await;
3358
3359 let (_workspace, thread_store, thread, _context_store, _model) =
3360 setup_test_environment(cx, project.clone()).await;
3361
3362 // Profile gets serialized with default values
3363 let serialized = thread
3364 .update(cx, |thread, cx| thread.serialize(cx))
3365 .await
3366 .unwrap();
3367
3368 assert_eq!(serialized.profile, Some(AgentProfileId::default()));
3369
3370 let deserialized = cx.update(|cx| {
3371 thread.update(cx, |thread, cx| {
3372 Thread::deserialize(
3373 thread.id.clone(),
3374 serialized,
3375 thread.project.clone(),
3376 thread.tools.clone(),
3377 thread.prompt_builder.clone(),
3378 thread.project_context.clone(),
3379 None,
3380 cx,
3381 )
3382 })
3383 });
3384 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3385
3386 assert_eq!(
3387 deserialized.profile,
3388 AgentProfile::new(AgentProfileId::default(), tool_set)
3389 );
3390 }
3391
    // Temperature overrides in `AgentSettings::model_parameters` are matched
    // against the request's provider and model; a `None` field acts as a
    // wildcard, and a mismatch on either specified field disables the override.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both provider and model specified and matching: override applies.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only the model specified (provider is a wildcard): override applies.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only the provider specified (model is a wildcard): override applies.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name but a different provider ("anthropic" is not the
        // fake test provider), so the override must NOT apply.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }
3485
    // Exercises the `ThreadSummary` state machine: Pending -> Generating ->
    // Ready, and which of those states accept a manually set summary.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending; `or_default` shows the placeholder.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set (chunks concatenated in arrival order)
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3570
3571 #[gpui::test]
3572 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3573 init_test_settings(cx);
3574
3575 let project = create_test_project(cx, json!({})).await;
3576
3577 let (_, _thread_store, thread, _context_store, model) =
3578 setup_test_environment(cx, project.clone()).await;
3579
3580 test_summarize_error(&model, &thread, cx);
3581
3582 // Now we should be able to set a summary
3583 thread.update(cx, |thread, cx| {
3584 thread.set_summary("Brief Intro", cx);
3585 });
3586
3587 thread.read_with(cx, |thread, _| {
3588 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3589 assert_eq!(thread.summary().or_default(), "Brief Intro");
3590 });
3591 }
3592
    // After summarization fails, the summary stays in `Error`: sending more
    // messages does not retry automatically, but calling `summarize` does.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the summary state machine into `Error`.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary for the manual retry.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3643
3644 #[gpui::test]
3645 fn test_resolve_tool_name_conflicts() {
3646 use assistant_tool::{Tool, ToolSource};
3647
3648 assert_resolve_tool_name_conflicts(
3649 vec![
3650 TestTool::new("tool1", ToolSource::Native),
3651 TestTool::new("tool2", ToolSource::Native),
3652 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3653 ],
3654 vec!["tool1", "tool2", "tool3"],
3655 );
3656
3657 assert_resolve_tool_name_conflicts(
3658 vec![
3659 TestTool::new("tool1", ToolSource::Native),
3660 TestTool::new("tool2", ToolSource::Native),
3661 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3662 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
3663 ],
3664 vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
3665 );
3666
3667 assert_resolve_tool_name_conflicts(
3668 vec![
3669 TestTool::new("tool1", ToolSource::Native),
3670 TestTool::new("tool2", ToolSource::Native),
3671 TestTool::new("tool3", ToolSource::Native),
3672 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3673 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
3674 ],
3675 vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
3676 );
3677
3678 // Test that tool with very long name is always truncated
3679 assert_resolve_tool_name_conflicts(
3680 vec![TestTool::new(
3681 "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
3682 ToolSource::Native,
3683 )],
3684 vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
3685 );
3686
3687 // Test deduplication of tools with very long names, in this case the mcp server name should be truncated
3688 assert_resolve_tool_name_conflicts(
3689 vec![
3690 TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
3691 TestTool::new(
3692 "tool-with-very-very-very-long-name",
3693 ToolSource::ContextServer {
3694 id: "mcp-with-very-very-very-long-name".into(),
3695 },
3696 ),
3697 ],
3698 vec![
3699 "tool-with-very-very-very-long-name",
3700 "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
3701 ],
3702 );
3703
3704 fn assert_resolve_tool_name_conflicts(
3705 tools: Vec<TestTool>,
3706 expected: Vec<impl Into<String>>,
3707 ) {
3708 let tools: Vec<Arc<dyn Tool>> = tools
3709 .into_iter()
3710 .map(|t| Arc::new(t) as Arc<dyn Tool>)
3711 .collect();
3712 let tools = resolve_tool_name_conflicts(&tools);
3713 assert_eq!(tools.len(), expected.len());
3714 for (i, expected_name) in expected.into_iter().enumerate() {
3715 let expected_name = expected_name.into();
3716 let actual_name = &tools[i].0;
3717 assert_eq!(
3718 actual_name, &expected_name,
3719 "Expected '{}' got '{}' at index {}",
3720 expected_name, actual_name, i
3721 );
3722 }
3723 }
3724
3725 struct TestTool {
3726 name: String,
3727 source: ToolSource,
3728 }
3729
3730 impl TestTool {
3731 fn new(name: impl Into<String>, source: ToolSource) -> Self {
3732 Self {
3733 name: name.into(),
3734 source,
3735 }
3736 }
3737 }
3738
3739 impl Tool for TestTool {
3740 fn name(&self) -> String {
3741 self.name.clone()
3742 }
3743
3744 fn icon(&self) -> IconName {
3745 IconName::Ai
3746 }
3747
3748 fn may_perform_edits(&self) -> bool {
3749 false
3750 }
3751
3752 fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
3753 true
3754 }
3755
3756 fn source(&self) -> ToolSource {
3757 self.source.clone()
3758 }
3759
3760 fn description(&self) -> String {
3761 "Test tool".to_string()
3762 }
3763
3764 fn ui_text(&self, _input: &serde_json::Value) -> String {
3765 "Test tool".to_string()
3766 }
3767
3768 fn run(
3769 self: Arc<Self>,
3770 _input: serde_json::Value,
3771 _request: Arc<LanguageModelRequest>,
3772 _project: Entity<Project>,
3773 _action_log: Entity<ActionLog>,
3774 _model: Arc<dyn LanguageModel>,
3775 _window: Option<AnyWindowHandle>,
3776 _cx: &mut App,
3777 ) -> assistant_tool::ToolResult {
3778 assistant_tool::ToolResult {
3779 output: Task::ready(Err(anyhow::anyhow!("No content"))),
3780 card: None,
3781 }
3782 }
3783 }
3784 }
3785
    /// Shared helper for the summary-error tests: sends a message, streams a
    /// normal assistant response, then ends the summarization stream without
    /// emitting any text, leaving `thread.summary()` in the `Error` state.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Summary generation is in flight; `or_default` yields the placeholder.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending (stream closes with no text)
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3820
    /// Streams a canned "Assistant response" completion through the fake model,
    /// letting the executor settle before and after so the thread observes it.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3827
    /// Installs the global settings store and registers every settings type the
    /// thread tests rely on. Called first by each test in this module.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The global store is set before anything registers settings on it.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3844
3845 // Helper to create a test project with test files
3846 async fn create_test_project(
3847 cx: &mut TestAppContext,
3848 files: serde_json::Value,
3849 ) -> Entity<Project> {
3850 let fs = FakeFs::new(cx.executor());
3851 fs.insert_tree(path!("/test"), files).await;
3852 Project::test(fs, [path!("/test").as_ref()], cx).await
3853 }
3854
3855 async fn setup_test_environment(
3856 cx: &mut TestAppContext,
3857 project: Entity<Project>,
3858 ) -> (
3859 Entity<Workspace>,
3860 Entity<ThreadStore>,
3861 Entity<Thread>,
3862 Entity<ContextStore>,
3863 Arc<dyn LanguageModel>,
3864 ) {
3865 let (workspace, cx) =
3866 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3867
3868 let thread_store = cx
3869 .update(|_, cx| {
3870 ThreadStore::load(
3871 project.clone(),
3872 cx.new(|_| ToolWorkingSet::default()),
3873 None,
3874 Arc::new(PromptBuilder::new(None).unwrap()),
3875 cx,
3876 )
3877 })
3878 .await
3879 .unwrap();
3880
3881 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3882 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3883
3884 let provider = Arc::new(FakeLanguageModelProvider);
3885 let model = provider.test_model();
3886 let model: Arc<dyn LanguageModel> = Arc::new(model);
3887
3888 cx.update(|_, cx| {
3889 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3890 registry.set_default_model(
3891 Some(ConfiguredModel {
3892 provider: provider.clone(),
3893 model: model.clone(),
3894 }),
3895 cx,
3896 );
3897 registry.set_thread_summary_model(
3898 Some(ConfiguredModel {
3899 provider,
3900 model: model.clone(),
3901 }),
3902 cx,
3903 );
3904 })
3905 });
3906
3907 (workspace, thread_store, thread, context_store, model)
3908 }
3909
3910 async fn add_file_to_context(
3911 project: &Entity<Project>,
3912 context_store: &Entity<ContextStore>,
3913 path: &str,
3914 cx: &mut TestAppContext,
3915 ) -> Result<Entity<language::Buffer>> {
3916 let buffer_path = project
3917 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3918 .unwrap();
3919
3920 let buffer = project
3921 .update(cx, |project, cx| {
3922 project.open_buffer(buffer_path.clone(), cx)
3923 })
3924 .await
3925 .unwrap();
3926
3927 context_store.update(cx, |context_store, cx| {
3928 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3929 });
3930
3931 Ok(buffer)
3932 }
3933}