1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
8use anyhow::{Result, anyhow};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{HashMap, HashSet};
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
42
43use crate::ThreadStore;
44use crate::agent_profile::AgentProfile;
45use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
46use crate::thread_store::{
47 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
48 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
49};
50use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
51
52#[derive(
53 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
54)]
55pub struct ThreadId(Arc<str>);
56
57impl ThreadId {
58 pub fn new() -> Self {
59 Self(Uuid::new_v4().to_string().into())
60 }
61}
62
63impl std::fmt::Display for ThreadId {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 write!(f, "{}", self.0)
66 }
67}
68
69impl From<&str> for ThreadId {
70 fn from(value: &str) -> Self {
71 Self(value.into())
72 }
73}
74
75/// The ID of the user prompt that initiated a request.
76///
77/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
78#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
79pub struct PromptId(Arc<str>);
80
81impl PromptId {
82 pub fn new() -> Self {
83 Self(Uuid::new_v4().to_string().into())
84 }
85}
86
87impl std::fmt::Display for PromptId {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 write!(f, "{}", self.0)
90 }
91}
92
/// Monotonically increasing identifier of a message within a single thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
101
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Range within the message's text covered by the crease (serialized as start/end).
    pub range: Range<usize>,
    /// Icon path and label used to render the collapsed placeholder.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
110
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces: text, thinking, or redacted thinking.
    pub segments: Vec<MessageSegment>,
    /// Context attached to this message; its text is prepended when the
    /// message is stringified (see `Message::to_string`).
    pub loaded_context: LoadedContext,
    /// Creases to restore when an editor is recreated for this message.
    pub creases: Vec<MessageCrease>,
    /// Hidden messages, e.g. the synthetic "continue" message inserted by
    /// `insert_invisible_continue_message`.
    pub is_hidden: bool,
}
121
122impl Message {
123 /// Returns whether the message contains any meaningful text that should be displayed
124 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
125 pub fn should_display_content(&self) -> bool {
126 self.segments.iter().all(|segment| segment.should_display())
127 }
128
129 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
130 if let Some(MessageSegment::Thinking {
131 text: segment,
132 signature: current_signature,
133 }) = self.segments.last_mut()
134 {
135 if let Some(signature) = signature {
136 *current_signature = Some(signature);
137 }
138 segment.push_str(text);
139 } else {
140 self.segments.push(MessageSegment::Thinking {
141 text: text.to_string(),
142 signature,
143 });
144 }
145 }
146
147 pub fn push_text(&mut self, text: &str) {
148 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
149 segment.push_str(text);
150 } else {
151 self.segments.push(MessageSegment::Text(text.to_string()));
152 }
153 }
154
155 pub fn to_string(&self) -> String {
156 let mut result = String::new();
157
158 if !self.loaded_context.text.is_empty() {
159 result.push_str(&self.loaded_context.text);
160 }
161
162 for segment in &self.segments {
163 match segment {
164 MessageSegment::Text(text) => result.push_str(text),
165 MessageSegment::Thinking { text, .. } => {
166 result.push_str("<think>\n");
167 result.push_str(text);
168 result.push_str("\n</think>");
169 }
170 MessageSegment::RedactedThinking(_) => {}
171 }
172 }
173
174 result
175 }
176}
177
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain response text.
    Text(String),
    /// Chain-of-thought text, optionally carrying a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque, provider-encoded reasoning payload; never rendered.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment carries displayable content.
    ///
    /// Text and thinking segments are displayable only when non-empty;
    /// redacted thinking is never displayed. (The previous implementation
    /// returned `text.is_empty()`, inverting the check so empty segments
    /// counted as displayable and real content as hidden.)
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
197
/// Snapshot of the project's worktrees, captured when a thread starts
/// (see `Thread::initial_project_snapshot`) and persisted with the thread.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree portion of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// `None` when the worktree is not a git repository.
    pub git_state: Option<GitState>,
}

/// Git repository state recorded for a worktree snapshot.
/// All fields are optional since any of them may be unavailable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
218
/// Associates a message with the git checkpoint taken for it, so the project
/// can later be restored to the state it had at that message.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback recorded for a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
230
231pub enum LastRestoreCheckpoint {
232 Pending {
233 message_id: MessageId,
234 },
235 Error {
236 message_id: MessageId,
237 error: String,
238 },
239}
240
241impl LastRestoreCheckpoint {
242 pub fn message_id(&self) -> MessageId {
243 match self {
244 LastRestoreCheckpoint::Pending { message_id } => *message_id,
245 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
246 }
247 }
248}
249
250#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
251pub enum DetailedSummaryState {
252 #[default]
253 NotGenerated,
254 Generating {
255 message_id: MessageId,
256 },
257 Generated {
258 text: SharedString,
259 message_id: MessageId,
260 },
261}
262
263impl DetailedSummaryState {
264 fn text(&self) -> Option<SharedString> {
265 if let Self::Generated { text, .. } = self {
266 Some(text.clone())
267 } else {
268 None
269 }
270 }
271}
272
/// Running token count for a thread, measured against the model's
/// context-window size (`max`).
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage against the context-window size.
    pub fn ratio(&self) -> TokenUsageRatio {
        // Debug builds allow overriding the warning threshold via the
        // ZED_THREAD_WARNING_THRESHOLD env var; release builds fix it at 0.8.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = match std::env::var("ZED_THREAD_WARNING_THRESHOLD") {
            Ok(value) => value.parse().unwrap(),
            Err(_) => 0.8,
        };
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            return TokenUsageRatio::Normal;
        }
        if self.total >= self.max {
            return TokenUsageRatio::Exceeded;
        }
        if self.total as f32 / self.max as f32 >= warning_threshold {
            return TokenUsageRatio::Warning;
        }
        TokenUsageRatio::Normal
    }

    /// Returns a copy of this usage with `tokens` more counted against it.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            max: self.max,
            total: self.total + tokens,
        }
    }
}

/// Bucketed severity of token usage relative to the context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
317
/// Progress of a pending completion request through the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// The request is being sent; no queue information yet.
    Sending,
    /// Queued behind `position` other requests.
    Queued { position: usize },
    /// The completion has started streaming.
    Started,
}
324
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    /// Short title for the thread (see [`ThreadSummary`]).
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    /// Messages in ascending id order; ids are assigned monotonically.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the message they were recorded for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint held for the in-flight user message; finalized or dropped
    /// once the turn settles (see `finalize_pending_checkpoint`).
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Per-request token usage history; `cumulative_token_usage` is the running total.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    /// Time the most recent streamed chunk arrived; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    /// Optional observer of outgoing requests and their raw streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
366
367#[derive(Clone, Debug, PartialEq, Eq)]
368pub enum ThreadSummary {
369 Pending,
370 Generating,
371 Ready(SharedString),
372 Error,
373}
374
375impl ThreadSummary {
376 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
377
378 pub fn or_default(&self) -> SharedString {
379 self.unwrap_or(Self::DEFAULT)
380 }
381
382 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
383 self.ready().unwrap_or_else(|| message.into())
384 }
385
386 pub fn ready(&self) -> Option<SharedString> {
387 match self {
388 ThreadSummary::Ready(summary) => Some(summary.clone()),
389 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
390 }
391 }
392}
393
/// Recorded when a request exceeded the model's context window, so the UI can
/// explain why the last message failed.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
401
402impl Thread {
    /// Creates an empty thread.
    ///
    /// Completion mode and default profile come from global `AgentSettings`,
    /// and the model defaults to the registry's default model. A project
    /// snapshot is captured asynchronously at creation time.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Kick off the snapshot immediately; consumers await the shared task.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
458
    /// Reconstructs a thread from its serialized form.
    ///
    /// Model, completion mode, and profile fall back to global defaults when
    /// the serialized thread predates those fields or the saved model is no
    /// longer available.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest persisted message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to the default.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the rendered context text survives serialization;
                    // context handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
583
    /// Installs a callback invoked with each outgoing request and the raw
    /// completion events received for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    /// Switches the active tool profile; emits `ProfileChanged` only on an
    /// actual change.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }
606
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the last-modified timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, lazily falling back to (and caching)
    /// the registry's default model when none has been chosen yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
646
647 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
648 let current_summary = match &self.summary {
649 ThreadSummary::Pending | ThreadSummary::Generating => return,
650 ThreadSummary::Ready(summary) => summary,
651 ThreadSummary::Error => &ThreadSummary::DEFAULT,
652 };
653
654 let mut new_summary = new_summary.into();
655
656 if new_summary.is_empty() {
657 new_summary = ThreadSummary::DEFAULT;
658 }
659
660 if current_summary != &new_summary {
661 self.summary = ThreadSummary::Ready(new_summary);
662 cx.emit(ThreadEvent::SummaryChanged);
663 }
664 }
665
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    /// Looks up a message by id.
    ///
    /// Uses binary search, relying on `messages` being sorted by id — which
    /// holds because ids are assigned monotonically and messages are appended
    /// in order.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }
682
    /// Iterates over all messages in id order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// True while a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    ///
    /// Returns `None` when no chunk timestamp has been recorded yet. The
    /// timestamp is set by `received_chunk`, so the result is presumably only
    /// meaningful while `is_generating()` is true — confirm before relying on
    /// it outside an active completion.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // A stream counts as stale after 250ms without a new chunk.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived (feeds `is_generation_stale`).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue status of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up a pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting on user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
731
    /// The checkpoint recorded for message `id`, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }

    /// Restores the project's git state to `checkpoint`, truncating the
    /// thread back to the checkpoint's message on success. Progress and
    /// failure are surfaced via `last_restore_checkpoint` and
    /// `ThreadEvent::CheckpointChanged`.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failure visible; the thread itself is untouched.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Messages after the checkpoint no longer match the
                    // restored project state, so drop them.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
771
772 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
773 let pending_checkpoint = if self.is_generating() {
774 return;
775 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
776 checkpoint
777 } else {
778 return;
779 };
780
781 self.finalize_checkpoint(pending_checkpoint, cx);
782 }
783
    /// Compares `pending_checkpoint` against the repository's current state
    /// and records it only if something actually changed in between.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // A failed comparison is treated as "changed".
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // If a new checkpoint can't be taken, conservatively keep the
            // pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
818
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Status of the most recent restore, if one is in flight or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
829
830 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
831 let Some(message_ix) = self
832 .messages
833 .iter()
834 .rposition(|message| message.id == message_id)
835 else {
836 return;
837 };
838 for deleted_message in self.messages.drain(message_ix..) {
839 self.checkpoints_by_message.remove(&deleted_message.id);
840 }
841 cx.notify();
842 }
843
844 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
845 self.messages
846 .iter()
847 .find(|message| message.id == id)
848 .into_iter()
849 .flat_map(|message| message.loaded_context.contexts.iter())
850 }
851
    /// Returns whether the message at index `ix` ends a turn: either it is
    /// the last message of a settled thread, or it is an assistant message
    /// followed by a visible user message.
    pub fn is_turn_end(&self, ix: usize) -> bool {
        if self.messages.is_empty() {
            return false;
        }

        if !self.is_generating() && ix == self.messages.len() - 1 {
            return true;
        }

        let Some(message) = self.messages.get(ix) else {
            return false;
        };

        if message.role != Role::Assistant {
            return false;
        }

        self.messages
            .get(ix + 1)
            .and_then(|message| {
                // NOTE(review): `self.message(message.id)` re-resolves the
                // message we already hold from `get(ix + 1)`; this looks
                // redundant — confirm before simplifying.
                self.message(message.id)
                    .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
            })
            .unwrap_or(false)
    }
877
    /// Usage reported for the most recent request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }
904
    /// Tool uses issued by the assistant message `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results attached to the assistant message that requested them.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Text output of a finished tool use; image results are not yet surfaced.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// UI card associated with a tool result, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
933
934 /// Return tools that are both enabled and supported by the model
935 pub fn available_tools(
936 &self,
937 cx: &App,
938 model: Arc<dyn LanguageModel>,
939 ) -> Vec<LanguageModelRequestTool> {
940 if model.supports_tools() {
941 resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
942 .into_iter()
943 .filter_map(|(name, tool)| {
944 // Skip tools that cannot be supported
945 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
946 Some(LanguageModelRequestTool {
947 name,
948 description: tool.description(),
949 input_schema,
950 })
951 })
952 .collect()
953 } else {
954 Vec::default()
955 }
956 }
957
    /// Appends a user message, registering any context-referenced buffers
    /// with the action log and optionally holding a git checkpoint for the
    /// turn.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        // Mark context buffers as read so subsequent edits can be tracked.
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        // Hold the checkpoint until the turn settles; whether it is kept is
        // decided in `finalize_pending_checkpoint`.
        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
994
    /// Inserts a hidden user message asking the model to continue, without
    /// showing it in the conversation.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        // A synthetic continuation should not create a restore point.
        self.pending_checkpoint = None;

        id
    }

    /// Appends an assistant message composed of the given segments.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1023
1024 pub fn insert_message(
1025 &mut self,
1026 role: Role,
1027 segments: Vec<MessageSegment>,
1028 loaded_context: LoadedContext,
1029 creases: Vec<MessageCrease>,
1030 is_hidden: bool,
1031 cx: &mut Context<Self>,
1032 ) -> MessageId {
1033 let id = self.next_message_id.post_inc();
1034 self.messages.push(Message {
1035 id,
1036 role,
1037 segments,
1038 loaded_context,
1039 creases,
1040 is_hidden,
1041 });
1042 self.touch_updated_at();
1043 cx.emit(ThreadEvent::MessageAdded(id));
1044 id
1045 }
1046
1047 pub fn edit_message(
1048 &mut self,
1049 id: MessageId,
1050 new_role: Role,
1051 new_segments: Vec<MessageSegment>,
1052 creases: Vec<MessageCrease>,
1053 loaded_context: Option<LoadedContext>,
1054 checkpoint: Option<GitStoreCheckpoint>,
1055 cx: &mut Context<Self>,
1056 ) -> bool {
1057 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1058 return false;
1059 };
1060 message.role = new_role;
1061 message.segments = new_segments;
1062 message.creases = creases;
1063 if let Some(context) = loaded_context {
1064 message.loaded_context = context;
1065 }
1066 if let Some(git_checkpoint) = checkpoint {
1067 self.checkpoints_by_message.insert(
1068 id,
1069 ThreadCheckpoint {
1070 message_id: id,
1071 git_checkpoint,
1072 },
1073 );
1074 }
1075 self.touch_updated_at();
1076 cx.emit(ThreadEvent::MessageEdited(id));
1077 true
1078 }
1079
1080 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1081 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1082 return false;
1083 };
1084 self.messages.remove(index);
1085 self.touch_updated_at();
1086 cx.emit(ThreadEvent::MessageDeleted(id));
1087 true
1088 }
1089
1090 /// Returns the representation of this [`Thread`] in a textual form.
1091 ///
1092 /// This is the representation we use when attaching a thread as context to another thread.
1093 pub fn text(&self) -> String {
1094 let mut text = String::new();
1095
1096 for message in &self.messages {
1097 text.push_str(match message.role {
1098 language_model::Role::User => "User:",
1099 language_model::Role::Assistant => "Agent:",
1100 language_model::Role::System => "System:",
1101 });
1102 text.push('\n');
1103
1104 for segment in &message.segments {
1105 match segment {
1106 MessageSegment::Text(content) => text.push_str(content),
1107 MessageSegment::Thinking { text: content, .. } => {
1108 text.push_str(&format!("<think>{}</think>", content))
1109 }
1110 MessageSegment::RedactedThinking(_) => {}
1111 }
1112 }
1113 text.push('\n');
1114 }
1115
1116 text
1117 }
1118
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Produces a [`SerializedThread`] capturing the messages (segments, tool
    /// uses, tool results, loaded context, and creases), token usage, the
    /// configured model, completion mode, and profile. The initial project
    /// snapshot is awaited before the thread state is read, so serialization
    /// completes asynchronously.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot first; the thread is read afterwards so the
            // serialized state reflects the thread at completion time.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        // Each in-memory segment variant maps 1:1 onto its
                        // serialized counterpart.
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses and results are stored per message.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        // Creases are persisted as plain offset ranges plus their
                        // display metadata.
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1204
    /// Returns how many agentic turns this thread may still take before it
    /// stops sending requests to the model (see [`Self::send_to_model`]).
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1208
    /// Sets the number of agentic turns this thread may still take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1212
1213 pub fn send_to_model(
1214 &mut self,
1215 model: Arc<dyn LanguageModel>,
1216 intent: CompletionIntent,
1217 window: Option<AnyWindowHandle>,
1218 cx: &mut Context<Self>,
1219 ) {
1220 if self.remaining_turns == 0 {
1221 return;
1222 }
1223
1224 self.remaining_turns -= 1;
1225
1226 let request = self.to_completion_request(model.clone(), intent, cx);
1227
1228 self.stream_completion(request, model, window, cx);
1229 }
1230
1231 pub fn used_tools_since_last_user_message(&self) -> bool {
1232 for message in self.messages.iter().rev() {
1233 if self.tool_use.message_has_tool_results(message.id) {
1234 return true;
1235 } else if message.role == Role::User {
1236 return false;
1237 }
1238 }
1239
1240 false
1241 }
1242
    /// Builds the [`LanguageModelRequest`] for this thread: the system prompt,
    /// the full message history (including tool uses and their results),
    /// stale-file context, and the set of tools available to the model.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from the project context. On failure (or if
        // the context isn't ready yet) surface an error to the UI but still send
        // the request without a system message.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message eligible for prompt caching. Only
        // messages whose tool uses have all completed can be cached; the cache
        // flag itself is applied once, after the loop.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) goes ahead of the
            // message's own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool results are sent back as a separate User message
            // immediately following the assistant message that used the tools.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A pending tool use can't be cached and isn't sent at all.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        // Tell the model about files that changed since it last saw them.
        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1402
1403 fn to_summarize_request(
1404 &self,
1405 model: &Arc<dyn LanguageModel>,
1406 intent: CompletionIntent,
1407 added_user_message: String,
1408 cx: &App,
1409 ) -> LanguageModelRequest {
1410 let mut request = LanguageModelRequest {
1411 thread_id: None,
1412 prompt_id: None,
1413 intent: Some(intent),
1414 mode: None,
1415 messages: vec![],
1416 tools: Vec::new(),
1417 tool_choice: None,
1418 stop: Vec::new(),
1419 temperature: AgentSettings::temperature_for_model(model, cx),
1420 };
1421
1422 for message in &self.messages {
1423 let mut request_message = LanguageModelRequestMessage {
1424 role: message.role,
1425 content: Vec::new(),
1426 cache: false,
1427 };
1428
1429 for segment in &message.segments {
1430 match segment {
1431 MessageSegment::Text(text) => request_message
1432 .content
1433 .push(MessageContent::Text(text.clone())),
1434 MessageSegment::Thinking { .. } => {}
1435 MessageSegment::RedactedThinking(_) => {}
1436 }
1437 }
1438
1439 if request_message.content.is_empty() {
1440 continue;
1441 }
1442
1443 request.messages.push(request_message);
1444 }
1445
1446 request.messages.push(LanguageModelRequestMessage {
1447 role: Role::User,
1448 content: vec![MessageContent::Text(added_user_message)],
1449 cache: false,
1450 });
1451
1452 request
1453 }
1454
1455 fn attached_tracked_files_state(
1456 &self,
1457 messages: &mut Vec<LanguageModelRequestMessage>,
1458 cx: &App,
1459 ) {
1460 const STALE_FILES_HEADER: &str = include_str!("./prompts/stale_files_prompt_header.txt");
1461
1462 let mut stale_message = String::new();
1463
1464 let action_log = self.action_log.read(cx);
1465
1466 for stale_file in action_log.stale_buffers(cx) {
1467 let Some(file) = stale_file.read(cx).file() else {
1468 continue;
1469 };
1470
1471 if stale_message.is_empty() {
1472 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER.trim()).ok();
1473 }
1474
1475 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1476 }
1477
1478 let mut content = Vec::with_capacity(2);
1479
1480 if !stale_message.is_empty() {
1481 content.push(stale_message.into());
1482 }
1483
1484 if !content.is_empty() {
1485 let context_message = LanguageModelRequestMessage {
1486 role: Role::User,
1487 content,
1488 cache: false,
1489 };
1490
1491 messages.push(context_message);
1492 }
1493 }
1494
    /// Streams a completion from `model` for `request`, applying events to the
    /// thread as they arrive (text/thinking chunks, tool uses, usage and status
    /// updates) and handling the final stop reason or error once the stream
    /// ends. The spawned task is tracked in `self.pending_completions`.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only record the request/response pair when a debug callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before streaming so the delta can be reported to
            // telemetry afterwards.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // Id of the assistant message this response streams into, once
                // one has been created.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        // Unwrap the event: bad tool-input JSON is recoverable
                        // (converted into an error tool result); other errors
                        // abort the stream.
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                            Err(err @ LanguageModelCompletionError::RateLimit(..)) => {
                                return Err(err.into());
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // add only the delta since the previous update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to the current assistant
                                // message, creating an empty one if needed.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // NOTE(review): `(&tool_use.input).clone()` is
                                // equivalent to `tool_use.input.clone()`; kept
                                // byte-identical here.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                // Only partially-streamed inputs are surfaced as
                                // streaming events.
                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield between events so other tasks (e.g. UI updates) can run.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                            StopReason::Refusal => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });

                                // Remove the turn that was refused.
                                //
                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                {
                                    let mut messages_to_remove = Vec::new();

                                    // Walk backwards collecting messages until the
                                    // start of the refused turn (the user message
                                    // following the previous assistant message).
                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
                                        messages_to_remove.push(message.id);

                                        if message.role == Role::User {
                                            if ix == 0 {
                                                break;
                                            }

                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                if prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }
                                    }

                                    for message_id in messages_to_remove {
                                        thread.delete_message(message_id, cx);
                                    }
                                }

                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Language model refusal".into(),
                                    message: "Model refused to generate content for safety reasons.".into(),
                                }));
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to their dedicated UI
                            // surfaces; anything else becomes a generic message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }

                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    // Report the token usage delta for this completion.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1879
1880 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1881 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1882 println!("No thread summary model");
1883 return;
1884 };
1885
1886 if !model.provider.is_authenticated(cx) {
1887 return;
1888 }
1889
1890 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
1891
1892 let request = self.to_summarize_request(
1893 &model.model,
1894 CompletionIntent::ThreadSummarization,
1895 added_user_message.into(),
1896 cx,
1897 );
1898
1899 self.summary = ThreadSummary::Generating;
1900
1901 self.pending_summary = cx.spawn(async move |this, cx| {
1902 let result = async {
1903 let mut messages = model.model.stream_completion(request, &cx).await?;
1904
1905 let mut new_summary = String::new();
1906 while let Some(event) = messages.next().await {
1907 let Ok(event) = event else {
1908 continue;
1909 };
1910 let text = match event {
1911 LanguageModelCompletionEvent::Text(text) => text,
1912 LanguageModelCompletionEvent::StatusUpdate(
1913 CompletionRequestStatus::UsageUpdated { amount, limit },
1914 ) => {
1915 this.update(cx, |thread, _cx| {
1916 thread.last_usage = Some(RequestUsage {
1917 limit,
1918 amount: amount as i32,
1919 });
1920 })?;
1921 continue;
1922 }
1923 _ => continue,
1924 };
1925
1926 let mut lines = text.lines();
1927 new_summary.extend(lines.next());
1928
1929 // Stop if the LLM generated multiple lines.
1930 if lines.next().is_some() {
1931 break;
1932 }
1933 }
1934
1935 anyhow::Ok(new_summary)
1936 }
1937 .await;
1938
1939 this.update(cx, |this, cx| {
1940 match result {
1941 Ok(new_summary) => {
1942 if new_summary.is_empty() {
1943 this.summary = ThreadSummary::Error;
1944 } else {
1945 this.summary = ThreadSummary::Ready(new_summary.into());
1946 }
1947 }
1948 Err(err) => {
1949 this.summary = ThreadSummary::Error;
1950 log::error!("Failed to generate thread summary: {}", err);
1951 }
1952 }
1953 cx.emit(ThreadEvent::SummaryGenerated);
1954 })
1955 .log_err()?;
1956
1957 Some(())
1958 });
1959 }
1960
    /// Starts generating a detailed (multi-paragraph) summary of this thread,
    /// unless one is already generated — or being generated — for the latest
    /// message. Progress is published through `detailed_summary_tx`, and the
    /// thread is saved via `thread_store` once the summary is ready.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // An empty thread has nothing to summarize.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Publish the Generating state before spawning so observers see it
        // immediately.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request fails outright, roll the state back to NotGenerated.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2050
2051 pub async fn wait_for_detailed_summary_or_text(
2052 this: &Entity<Self>,
2053 cx: &mut AsyncApp,
2054 ) -> Option<SharedString> {
2055 let mut detailed_summary_rx = this
2056 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2057 .ok()?;
2058 loop {
2059 match detailed_summary_rx.recv().await? {
2060 DetailedSummaryState::Generating { .. } => {}
2061 DetailedSummaryState::NotGenerated => {
2062 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2063 }
2064 DetailedSummaryState::Generated { text, .. } => return Some(text),
2065 }
2066 }
2067 }
2068
2069 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2070 self.detailed_summary_rx
2071 .borrow()
2072 .text()
2073 .unwrap_or_else(|| self.text().into())
2074 }
2075
    /// Whether a detailed summary is currently being generated for this thread.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2082
2083 pub fn use_pending_tools(
2084 &mut self,
2085 window: Option<AnyWindowHandle>,
2086 cx: &mut Context<Self>,
2087 model: Arc<dyn LanguageModel>,
2088 ) -> Vec<PendingToolUse> {
2089 self.auto_capture_telemetry(cx);
2090 let request =
2091 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2092 let pending_tool_uses = self
2093 .tool_use
2094 .pending_tool_uses()
2095 .into_iter()
2096 .filter(|tool_use| tool_use.status.is_idle())
2097 .cloned()
2098 .collect::<Vec<_>>();
2099
2100 for tool_use in pending_tool_uses.iter() {
2101 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
2102 if tool.needs_confirmation(&tool_use.input, cx)
2103 && !AgentSettings::get_global(cx).always_allow_tool_actions
2104 {
2105 self.tool_use.confirm_tool_use(
2106 tool_use.id.clone(),
2107 tool_use.ui_text.clone(),
2108 tool_use.input.clone(),
2109 request.clone(),
2110 tool,
2111 );
2112 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2113 } else {
2114 self.run_tool(
2115 tool_use.id.clone(),
2116 tool_use.ui_text.clone(),
2117 tool_use.input.clone(),
2118 request.clone(),
2119 tool,
2120 model.clone(),
2121 window,
2122 cx,
2123 );
2124 }
2125 } else {
2126 self.handle_hallucinated_tool_use(
2127 tool_use.id.clone(),
2128 tool_use.name.clone(),
2129 window,
2130 cx,
2131 );
2132 }
2133 }
2134
2135 pending_tool_uses
2136 }
2137
2138 pub fn handle_hallucinated_tool_use(
2139 &mut self,
2140 tool_use_id: LanguageModelToolUseId,
2141 hallucinated_tool_name: Arc<str>,
2142 window: Option<AnyWindowHandle>,
2143 cx: &mut Context<Thread>,
2144 ) {
2145 let available_tools = self.profile.enabled_tools(cx);
2146
2147 let tool_list = available_tools
2148 .iter()
2149 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2150 .collect::<Vec<_>>()
2151 .join("\n");
2152
2153 let error_message = format!(
2154 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2155 hallucinated_tool_name, tool_list
2156 );
2157
2158 let pending_tool_use = self.tool_use.insert_tool_output(
2159 tool_use_id.clone(),
2160 hallucinated_tool_name,
2161 Err(anyhow!("Missing tool call: {error_message}")),
2162 self.configured_model.as_ref(),
2163 );
2164
2165 cx.emit(ThreadEvent::MissingToolUse {
2166 tool_use_id: tool_use_id.clone(),
2167 ui_text: error_message.into(),
2168 });
2169
2170 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2171 }
2172
2173 pub fn receive_invalid_tool_json(
2174 &mut self,
2175 tool_use_id: LanguageModelToolUseId,
2176 tool_name: Arc<str>,
2177 invalid_json: Arc<str>,
2178 error: String,
2179 window: Option<AnyWindowHandle>,
2180 cx: &mut Context<Thread>,
2181 ) {
2182 log::error!("The model returned invalid input JSON: {invalid_json}");
2183
2184 let pending_tool_use = self.tool_use.insert_tool_output(
2185 tool_use_id.clone(),
2186 tool_name,
2187 Err(anyhow!("Error parsing input JSON: {error}")),
2188 self.configured_model.as_ref(),
2189 );
2190 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2191 pending_tool_use.ui_text.clone()
2192 } else {
2193 log::error!(
2194 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2195 );
2196 format!("Unknown tool {}", tool_use_id).into()
2197 };
2198
2199 cx.emit(ThreadEvent::InvalidToolInput {
2200 tool_use_id: tool_use_id.clone(),
2201 ui_text,
2202 invalid_input_json: invalid_json,
2203 });
2204
2205 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2206 }
2207
2208 pub fn run_tool(
2209 &mut self,
2210 tool_use_id: LanguageModelToolUseId,
2211 ui_text: impl Into<SharedString>,
2212 input: serde_json::Value,
2213 request: Arc<LanguageModelRequest>,
2214 tool: Arc<dyn Tool>,
2215 model: Arc<dyn LanguageModel>,
2216 window: Option<AnyWindowHandle>,
2217 cx: &mut Context<Thread>,
2218 ) {
2219 let task =
2220 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2221 self.tool_use
2222 .run_pending_tool(tool_use_id, ui_text.into(), task);
2223 }
2224
    /// Invokes `tool` immediately and returns a task that awaits its output,
    /// records the result on this thread, and signals completion.
    ///
    /// Any card produced by the tool is stored up front (before the output
    /// resolves) so the UI can render it while the tool is still running.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // The tool starts running here; only its output is awaited in the
        // spawned task below.
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    // The thread may have been dropped before the tool
                    // finished; in that case there is nothing left to record.
                    .ok();
            }
        })
    }
2271
2272 fn tool_finished(
2273 &mut self,
2274 tool_use_id: LanguageModelToolUseId,
2275 pending_tool_use: Option<PendingToolUse>,
2276 canceled: bool,
2277 window: Option<AnyWindowHandle>,
2278 cx: &mut Context<Self>,
2279 ) {
2280 if self.all_tools_finished() {
2281 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2282 if !canceled {
2283 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2284 }
2285 self.auto_capture_telemetry(cx);
2286 }
2287 }
2288
2289 cx.emit(ThreadEvent::ToolFinished {
2290 tool_use_id,
2291 pending_tool_use,
2292 });
2293 }
2294
2295 /// Cancels the last pending completion, if there are any pending.
2296 ///
2297 /// Returns whether a completion was canceled.
2298 pub fn cancel_last_completion(
2299 &mut self,
2300 window: Option<AnyWindowHandle>,
2301 cx: &mut Context<Self>,
2302 ) -> bool {
2303 let mut canceled = self.pending_completions.pop().is_some();
2304
2305 for pending_tool_use in self.tool_use.cancel_pending() {
2306 canceled = true;
2307 self.tool_finished(
2308 pending_tool_use.id.clone(),
2309 Some(pending_tool_use),
2310 true,
2311 window,
2312 cx,
2313 );
2314 }
2315
2316 if canceled {
2317 cx.emit(ThreadEvent::CompletionCanceled);
2318
2319 // When canceled, we always want to insert the checkpoint.
2320 // (We skip over finalize_pending_checkpoint, because it
2321 // would conclude we didn't have anything to insert here.)
2322 if let Some(checkpoint) = self.pending_checkpoint.take() {
2323 self.insert_checkpoint(checkpoint, cx);
2324 }
2325 } else {
2326 self.finalize_pending_checkpoint(cx);
2327 }
2328
2329 canceled
2330 }
2331
2332 /// Signals that any in-progress editing should be canceled.
2333 ///
2334 /// This method is used to notify listeners (like ActiveThread) that
2335 /// they should cancel any editing operations.
2336 pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
2337 cx.emit(ThreadEvent::CancelEditing);
2338 }
2339
    /// Returns the thread-level feedback rating, if one has been reported.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2343
    /// Returns the feedback rating recorded for `message_id`, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2347
2348 pub fn report_message_feedback(
2349 &mut self,
2350 message_id: MessageId,
2351 feedback: ThreadFeedback,
2352 cx: &mut Context<Self>,
2353 ) -> Task<Result<()>> {
2354 if self.message_feedback.get(&message_id) == Some(&feedback) {
2355 return Task::ready(Ok(()));
2356 }
2357
2358 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
2359 let serialized_thread = self.serialize(cx);
2360 let thread_id = self.id().clone();
2361 let client = self.project.read(cx).client();
2362
2363 let enabled_tool_names: Vec<String> = self
2364 .profile
2365 .enabled_tools(cx)
2366 .iter()
2367 .map(|tool| tool.name())
2368 .collect();
2369
2370 self.message_feedback.insert(message_id, feedback);
2371
2372 cx.notify();
2373
2374 let message_content = self
2375 .message(message_id)
2376 .map(|msg| msg.to_string())
2377 .unwrap_or_default();
2378
2379 cx.background_spawn(async move {
2380 let final_project_snapshot = final_project_snapshot.await;
2381 let serialized_thread = serialized_thread.await?;
2382 let thread_data =
2383 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
2384
2385 let rating = match feedback {
2386 ThreadFeedback::Positive => "positive",
2387 ThreadFeedback::Negative => "negative",
2388 };
2389 telemetry::event!(
2390 "Assistant Thread Rated",
2391 rating,
2392 thread_id,
2393 enabled_tool_names,
2394 message_id = message_id.0,
2395 message_content,
2396 thread_data,
2397 final_project_snapshot
2398 );
2399 client.telemetry().flush_events().await;
2400
2401 Ok(())
2402 })
2403 }
2404
    /// Reports thread-level feedback.
    ///
    /// If the thread contains an assistant message, the rating is attached to
    /// the most recent one via [`Self::report_message_feedback`]; otherwise it
    /// is recorded on the thread itself and reported via telemetry without the
    /// message-specific fields.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Prefer attaching the rating to the newest assistant message.
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                // The locals above are captured by name as telemetry keys.
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2450
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Worktree snapshots are collected concurrently; dirty buffer paths are
    /// gathered afterwards on the foreground via `cx.update`.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            let mut unsaved_buffers = Vec::new();
            // Best-effort: if the app has shut down, the snapshot simply omits
            // unsaved-buffer information.
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2488
    /// Captures a snapshot of a single worktree: its absolute path plus, when a
    /// local git repository covers it, the repo's remote URL, HEAD SHA, current
    /// branch, and HEAD-to-worktree diff.
    ///
    /// Every fallible step degrades gracefully: failures yield an empty path
    /// or a `None` git state rather than an error.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose root contains this worktree's path.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Only local repositories expose a backend we can
                            // query; otherwise report just the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                            diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<Task>> produced above; any
            // failure along the way collapses to `None`.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2566
2567 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2568 let mut markdown = Vec::new();
2569
2570 let summary = self.summary().or_default();
2571 writeln!(markdown, "# {summary}\n")?;
2572
2573 for message in self.messages() {
2574 writeln!(
2575 markdown,
2576 "## {role}\n",
2577 role = match message.role {
2578 Role::User => "User",
2579 Role::Assistant => "Agent",
2580 Role::System => "System",
2581 }
2582 )?;
2583
2584 if !message.loaded_context.text.is_empty() {
2585 writeln!(markdown, "{}", message.loaded_context.text)?;
2586 }
2587
2588 if !message.loaded_context.images.is_empty() {
2589 writeln!(
2590 markdown,
2591 "\n{} images attached as context.\n",
2592 message.loaded_context.images.len()
2593 )?;
2594 }
2595
2596 for segment in &message.segments {
2597 match segment {
2598 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2599 MessageSegment::Thinking { text, .. } => {
2600 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2601 }
2602 MessageSegment::RedactedThinking(_) => {}
2603 }
2604 }
2605
2606 for tool_use in self.tool_uses_for_message(message.id, cx) {
2607 writeln!(
2608 markdown,
2609 "**Use Tool: {} ({})**",
2610 tool_use.name, tool_use.id
2611 )?;
2612 writeln!(markdown, "```json")?;
2613 writeln!(
2614 markdown,
2615 "{}",
2616 serde_json::to_string_pretty(&tool_use.input)?
2617 )?;
2618 writeln!(markdown, "```")?;
2619 }
2620
2621 for tool_result in self.tool_results_for_message(message.id) {
2622 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2623 if tool_result.is_error {
2624 write!(markdown, " (Error)")?;
2625 }
2626
2627 writeln!(markdown, "**\n")?;
2628 match &tool_result.content {
2629 LanguageModelToolResultContent::Text(text) => {
2630 writeln!(markdown, "{text}")?;
2631 }
2632 LanguageModelToolResultContent::Image(image) => {
2633 writeln!(markdown, "", image.source)?;
2634 }
2635 }
2636
2637 if let Some(output) = tool_result.output.as_ref() {
2638 writeln!(
2639 markdown,
2640 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2641 serde_json::to_string_pretty(output)?
2642 )?;
2643 }
2644 }
2645 }
2646
2647 Ok(String::from_utf8_lossy(&markdown).to_string())
2648 }
2649
2650 pub fn keep_edits_in_range(
2651 &mut self,
2652 buffer: Entity<language::Buffer>,
2653 buffer_range: Range<language::Anchor>,
2654 cx: &mut Context<Self>,
2655 ) {
2656 self.action_log.update(cx, |action_log, cx| {
2657 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2658 });
2659 }
2660
2661 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2662 self.action_log
2663 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2664 }
2665
2666 pub fn reject_edits_in_ranges(
2667 &mut self,
2668 buffer: Entity<language::Buffer>,
2669 buffer_ranges: Vec<Range<language::Anchor>>,
2670 cx: &mut Context<Self>,
2671 ) -> Task<Result<()>> {
2672 self.action_log.update(cx, |action_log, cx| {
2673 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2674 })
2675 }
2676
    /// Returns this thread's action log entity.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2680
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2684
2685 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2686 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2687 return;
2688 }
2689
2690 let now = Instant::now();
2691 if let Some(last) = self.last_auto_capture_at {
2692 if now.duration_since(last).as_secs() < 10 {
2693 return;
2694 }
2695 }
2696
2697 self.last_auto_capture_at = Some(now);
2698
2699 let thread_id = self.id().clone();
2700 let github_login = self
2701 .project
2702 .read(cx)
2703 .user_store()
2704 .read(cx)
2705 .current_user()
2706 .map(|user| user.github_login.clone());
2707 let client = self.project.read(cx).client();
2708 let serialize_task = self.serialize(cx);
2709
2710 cx.background_executor()
2711 .spawn(async move {
2712 if let Ok(serialized_thread) = serialize_task.await {
2713 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2714 telemetry::event!(
2715 "Agent Thread Auto-Captured",
2716 thread_id = thread_id.to_string(),
2717 thread_data = thread_data,
2718 auto_capture_reason = "tracked_user",
2719 github_login = github_login
2720 );
2721
2722 client.telemetry().flush_events().await;
2723 }
2724 }
2725 })
2726 .detach();
2727 }
2728
    /// Returns the token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2732
2733 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2734 let Some(model) = self.configured_model.as_ref() else {
2735 return TotalTokenUsage::default();
2736 };
2737
2738 let max = model.model.max_token_count();
2739
2740 let index = self
2741 .messages
2742 .iter()
2743 .position(|msg| msg.id == message_id)
2744 .unwrap_or(0);
2745
2746 if index == 0 {
2747 return TotalTokenUsage { total: 0, max };
2748 }
2749
2750 let token_usage = &self
2751 .request_token_usage
2752 .get(index - 1)
2753 .cloned()
2754 .unwrap_or_default();
2755
2756 TotalTokenUsage {
2757 total: token_usage.total_tokens() as usize,
2758 max,
2759 }
2760 }
2761
2762 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2763 let model = self.configured_model.as_ref()?;
2764
2765 let max = model.model.max_token_count();
2766
2767 if let Some(exceeded_error) = &self.exceeded_window_error {
2768 if model.model.id() == exceeded_error.model_id {
2769 return Some(TotalTokenUsage {
2770 total: exceeded_error.token_count,
2771 max,
2772 });
2773 }
2774 }
2775
2776 let total = self
2777 .token_usage_at_last_message()
2778 .unwrap_or_default()
2779 .total_tokens() as usize;
2780
2781 Some(TotalTokenUsage { total, max })
2782 }
2783
2784 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2785 self.request_token_usage
2786 .get(self.messages.len().saturating_sub(1))
2787 .or_else(|| self.request_token_usage.last())
2788 .cloned()
2789 }
2790
2791 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2792 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2793 self.request_token_usage
2794 .resize(self.messages.len(), placeholder);
2795
2796 if let Some(last) = self.request_token_usage.last_mut() {
2797 *last = token_usage;
2798 }
2799 }
2800
2801 pub fn deny_tool_use(
2802 &mut self,
2803 tool_use_id: LanguageModelToolUseId,
2804 tool_name: Arc<str>,
2805 window: Option<AnyWindowHandle>,
2806 cx: &mut Context<Self>,
2807 ) {
2808 let err = Err(anyhow::anyhow!(
2809 "Permission to run tool action denied by user"
2810 ));
2811
2812 self.tool_use.insert_tool_output(
2813 tool_use_id.clone(),
2814 tool_name,
2815 err,
2816 self.configured_model.as_ref(),
2817 );
2818 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2819 }
2820}
2821
/// User-facing errors surfaced while driving a thread's completions.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before more requests are served.
    #[error("Payment required")]
    PaymentRequired,
    /// The plan's model-request limit has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a header and detail message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2834
/// Events emitted by a thread as completions stream in and tools run.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error should be shown to the user.
    ShowError(ThreadError),
    /// A completion event was received from the stream.
    StreamedCompletion,
    /// A text chunk arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Assistant text streamed in for the given message.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text streamed in for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// A tool use streamed in from the model.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// An expected tool use was not provided (see the stop-reason handling).
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model returned tool input that failed to parse as JSON; the tool is
    /// finished with an error result.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stream ended, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was added to the thread.
    MessageAdded(MessageId),
    /// A message was edited.
    MessageEdited(MessageId),
    /// A message was deleted.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Pending tool uses are ready to be run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully, with an error, or canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A checkpoint was inserted or updated.
    CheckpointChanged,
    /// A tool needs user confirmation before running.
    ToolConfirmationNeeded,
    /// The tool-use limit was reached.
    ToolUseLimitReached,
    /// Listeners should cancel any in-progress editing.
    CancelEditing,
    /// A pending completion was canceled.
    CompletionCanceled,
    /// The thread's agent profile changed.
    ProfileChanged,
}
2879
// Lets `Thread` entities broadcast `ThreadEvent`s to gpui subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
2881
/// An in-flight completion request owned by a thread.
struct PendingCompletion {
    // Monotonic id distinguishing pending completions from one another.
    id: usize,
    // Where this completion currently sits in the request queue.
    queue_state: QueueState,
    // Held to keep the streaming task alive for the lifetime of this entry
    // (gpui tasks stop when dropped — TODO confirm that is relied upon here).
    _task: Task<()>,
}
2887
2888/// Resolves tool name conflicts by ensuring all tool names are unique.
2889///
2890/// When multiple tools have the same name, this function applies the following rules:
2891/// 1. Native tools always keep their original name
2892/// 2. Context server tools get prefixed with their server ID and an underscore
2893/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
2894/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
2895///
2896/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
2897fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
2898 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
2899 let mut tool_name = tool.name();
2900 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
2901 tool_name
2902 }
2903
2904 const MAX_TOOL_NAME_LENGTH: usize = 64;
2905
2906 let mut duplicated_tool_names = HashSet::default();
2907 let mut seen_tool_names = HashSet::default();
2908 for tool in tools {
2909 let tool_name = resolve_tool_name(tool);
2910 if seen_tool_names.contains(&tool_name) {
2911 debug_assert!(
2912 tool.source() != assistant_tool::ToolSource::Native,
2913 "There are two built-in tools with the same name: {}",
2914 tool_name
2915 );
2916 duplicated_tool_names.insert(tool_name);
2917 } else {
2918 seen_tool_names.insert(tool_name);
2919 }
2920 }
2921
2922 if duplicated_tool_names.is_empty() {
2923 return tools
2924 .into_iter()
2925 .map(|tool| (resolve_tool_name(tool), tool.clone()))
2926 .collect();
2927 }
2928
2929 tools
2930 .into_iter()
2931 .filter_map(|tool| {
2932 let mut tool_name = resolve_tool_name(tool);
2933 if !duplicated_tool_names.contains(&tool_name) {
2934 return Some((tool_name, tool.clone()));
2935 }
2936 match tool.source() {
2937 assistant_tool::ToolSource::Native => {
2938 // Built-in tools always keep their original name
2939 Some((tool_name, tool.clone()))
2940 }
2941 assistant_tool::ToolSource::ContextServer { id } => {
2942 // Context server tools are prefixed with the context server ID, and truncated if necessary
2943 tool_name.insert(0, '_');
2944 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
2945 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
2946 let mut id = id.to_string();
2947 id.truncate(len);
2948 tool_name.insert_str(0, &id);
2949 } else {
2950 tool_name.insert_str(0, &id);
2951 }
2952
2953 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
2954
2955 if seen_tool_names.contains(&tool_name) {
2956 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
2957 None
2958 } else {
2959 Some((tool_name, tool.clone()))
2960 }
2961 }
2962 }
2963 })
2964 .collect()
2965}
2966
2967#[cfg(test)]
2968mod tests {
2969 use super::*;
2970 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2971 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
2972 use assistant_tool::ToolRegistry;
2973 use editor::EditorSettings;
2974 use gpui::TestAppContext;
2975 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2976 use project::{FakeFs, Project};
2977 use prompt_store::PromptBuilder;
2978 use serde_json::json;
2979 use settings::{Settings, SettingsStore};
2980 use std::sync::Arc;
2981 use theme::ThemeSettings;
2982 use ui::IconName;
2983 use util::path;
2984 use workspace::Workspace;
2985
    /// Verifies that a file attached as context is rendered into the message's
    /// loaded-context text and prepended to the message in the completion
    /// request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Messages: [system prompt, user message]; the context text is
        // prepended to the user's own text.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3062
    /// Verifies that each new message only carries context not already
    /// attached to an earlier message, and that `new_context_for_thread`
    /// honors an "up to message" cutoff.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, contexts attached at or after it (2, 3)
        // plus the newly added file4 count as "new".
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3218
    /// Verifies that messages inserted without any attached context have empty
    /// loaded-context text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Messages: [system prompt, user message].
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3296
    // Verifies that editing a buffer after it was included as message context
    // causes the next completion request to end with a "stale files" notice,
    // and that no such notice appears before any edit happens.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open the buffer and register it in the context store.
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        // Load the single stored context entry so it can be attached to a message.
        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert a user message with the buffer as context.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Build a request before any edit; it must not carry a stale-file warning.
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Make sure we don't have a stale file warning yet.
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer after it was captured as context.
        buffer.update(cx, |buffer, cx| {
            // Insert at byte offset 1 — the exact position is irrelevant; any
            // edit makes the previously-read buffer stale.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message, this time without context.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning.
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // We should have a stale file warning as the last message.
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The notification is injected as a trailing user-role message.
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message.
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3384
3385 #[gpui::test]
3386 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3387 init_test_settings(cx);
3388
3389 let project = create_test_project(
3390 cx,
3391 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3392 )
3393 .await;
3394
3395 let (_workspace, thread_store, thread, _context_store, _model) =
3396 setup_test_environment(cx, project.clone()).await;
3397
3398 // Check that we are starting with the default profile
3399 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3400 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3401 assert_eq!(
3402 profile,
3403 AgentProfile::new(AgentProfileId::default(), tool_set)
3404 );
3405 }
3406
3407 #[gpui::test]
3408 async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
3409 init_test_settings(cx);
3410
3411 let project = create_test_project(
3412 cx,
3413 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3414 )
3415 .await;
3416
3417 let (_workspace, thread_store, thread, _context_store, _model) =
3418 setup_test_environment(cx, project.clone()).await;
3419
3420 // Profile gets serialized with default values
3421 let serialized = thread
3422 .update(cx, |thread, cx| thread.serialize(cx))
3423 .await
3424 .unwrap();
3425
3426 assert_eq!(serialized.profile, Some(AgentProfileId::default()));
3427
3428 let deserialized = cx.update(|cx| {
3429 thread.update(cx, |thread, cx| {
3430 Thread::deserialize(
3431 thread.id.clone(),
3432 serialized,
3433 thread.project.clone(),
3434 thread.tools.clone(),
3435 thread.prompt_builder.clone(),
3436 thread.project_context.clone(),
3437 None,
3438 cx,
3439 )
3440 })
3441 });
3442 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3443
3444 assert_eq!(
3445 deserialized.profile,
3446 AgentProfile::new(AgentProfileId::default(), tool_set)
3447 );
3448 }
3449
3450 #[gpui::test]
3451 async fn test_temperature_setting(cx: &mut TestAppContext) {
3452 init_test_settings(cx);
3453
3454 let project = create_test_project(
3455 cx,
3456 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3457 )
3458 .await;
3459
3460 let (_workspace, _thread_store, thread, _context_store, model) =
3461 setup_test_environment(cx, project.clone()).await;
3462
3463 // Both model and provider
3464 cx.update(|cx| {
3465 AgentSettings::override_global(
3466 AgentSettings {
3467 model_parameters: vec![LanguageModelParameters {
3468 provider: Some(model.provider_id().0.to_string().into()),
3469 model: Some(model.id().0.clone()),
3470 temperature: Some(0.66),
3471 }],
3472 ..AgentSettings::get_global(cx).clone()
3473 },
3474 cx,
3475 );
3476 });
3477
3478 let request = thread.update(cx, |thread, cx| {
3479 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3480 });
3481 assert_eq!(request.temperature, Some(0.66));
3482
3483 // Only model
3484 cx.update(|cx| {
3485 AgentSettings::override_global(
3486 AgentSettings {
3487 model_parameters: vec![LanguageModelParameters {
3488 provider: None,
3489 model: Some(model.id().0.clone()),
3490 temperature: Some(0.66),
3491 }],
3492 ..AgentSettings::get_global(cx).clone()
3493 },
3494 cx,
3495 );
3496 });
3497
3498 let request = thread.update(cx, |thread, cx| {
3499 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3500 });
3501 assert_eq!(request.temperature, Some(0.66));
3502
3503 // Only provider
3504 cx.update(|cx| {
3505 AgentSettings::override_global(
3506 AgentSettings {
3507 model_parameters: vec![LanguageModelParameters {
3508 provider: Some(model.provider_id().0.to_string().into()),
3509 model: None,
3510 temperature: Some(0.66),
3511 }],
3512 ..AgentSettings::get_global(cx).clone()
3513 },
3514 cx,
3515 );
3516 });
3517
3518 let request = thread.update(cx, |thread, cx| {
3519 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3520 });
3521 assert_eq!(request.temperature, Some(0.66));
3522
3523 // Same model name, different provider
3524 cx.update(|cx| {
3525 AgentSettings::override_global(
3526 AgentSettings {
3527 model_parameters: vec![LanguageModelParameters {
3528 provider: Some("anthropic".into()),
3529 model: Some(model.id().0.clone()),
3530 temperature: Some(0.66),
3531 }],
3532 ..AgentSettings::get_global(cx).clone()
3533 },
3534 cx,
3535 );
3536 });
3537
3538 let request = thread.update(cx, |thread, cx| {
3539 thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
3540 });
3541 assert_eq!(request.temperature, None);
3542 }
3543
    // Exercises the full summary lifecycle: Pending -> Generating -> Ready,
    // checking when manual `set_summary` calls are ignored vs. accepted.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be Pending, presented as the default text.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary is ignored while still Pending.
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message so the thread has a conversation to summarize.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages.
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Manual set_summary is also ignored while Generating.
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary in two chunks, then end the stream; the chunks
        // must be concatenated into one summary string.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should now be Ready with the streamed text.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Once Ready, a manual override is accepted.
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Setting an empty summary keeps the Ready state but falls back to
        // the default display text.
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3628
3629 #[gpui::test]
3630 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3631 init_test_settings(cx);
3632
3633 let project = create_test_project(cx, json!({})).await;
3634
3635 let (_, _thread_store, thread, _context_store, model) =
3636 setup_test_environment(cx, project.clone()).await;
3637
3638 test_summarize_error(&model, &thread, cx);
3639
3640 // Now we should be able to set a summary
3641 thread.update(cx, |thread, cx| {
3642 thread.set_summary("Brief Intro", cx);
3643 });
3644
3645 thread.read_with(cx, |thread, _| {
3646 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3647 assert_eq!(thread.summary().or_default(), "Brief Intro");
3648 });
3649 }
3650
    // After a failed summarization, sending further messages must not retry
    // automatically, but an explicit `summarize` call must succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Put the thread's summary into the Error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating.
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually.
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // This time, stream a successful summary and finish the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3701
    // Checks that `resolve_tool_name_conflicts` passes unique names through
    // unchanged, prefixes duplicated context-server tool names with their
    // server id, and truncates over-long results (the expected strings below
    // are 64 characters, which appears to be the tool-name limit — see the
    // fourth case).
    #[gpui::test]
    fn test_resolve_tool_name_conflicts() {
        use assistant_tool::{Tool, ToolSource};

        // No conflicts: every name passes through unchanged.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
            ],
            vec!["tool1", "tool2", "tool3"],
        );

        // Two context servers expose the same tool name: both copies get
        // prefixed with their server id.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // A native tool sharing the name keeps it; only the context-server
        // duplicates are prefixed.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // A tool with a very long name is always truncated.
        assert_resolve_tool_name_conflicts(
            vec![TestTool::new(
                "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
                ToolSource::Native,
            )],
            vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
        );

        // Deduplication of tools with very long names: the mcp server prefix
        // is truncated so the combined name still fits the limit.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
                TestTool::new(
                    "tool-with-very-very-very-long-name",
                    ToolSource::ContextServer {
                        id: "mcp-with-very-very-very-long-name".into(),
                    },
                ),
            ],
            vec![
                "tool-with-very-very-very-long-name",
                "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
            ],
        );

        // Asserts that resolving `tools` yields exactly `expected` names, in order.
        fn assert_resolve_tool_name_conflicts(
            tools: Vec<TestTool>,
            expected: Vec<impl Into<String>>,
        ) {
            let tools: Vec<Arc<dyn Tool>> = tools
                .into_iter()
                .map(|t| Arc::new(t) as Arc<dyn Tool>)
                .collect();
            let tools = resolve_tool_name_conflicts(&tools);
            assert_eq!(tools.len(), expected.len());
            for (i, expected_name) in expected.into_iter().enumerate() {
                let expected_name = expected_name.into();
                let actual_name = &tools[i].0;
                assert_eq!(
                    actual_name, &expected_name,
                    "Expected '{}' got '{}' at index {}",
                    expected_name, actual_name, i
                );
            }
        }

        // Minimal `Tool` implementation whose name and source are the only
        // configurable parts — name resolution needs nothing else.
        struct TestTool {
            name: String,
            source: ToolSource,
        }

        impl TestTool {
            fn new(name: impl Into<String>, source: ToolSource) -> Self {
                Self {
                    name: name.into(),
                    source,
                }
            }
        }

        impl Tool for TestTool {
            fn name(&self) -> String {
                self.name.clone()
            }

            fn icon(&self) -> IconName {
                IconName::Ai
            }

            fn may_perform_edits(&self) -> bool {
                false
            }

            fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
                true
            }

            fn source(&self) -> ToolSource {
                self.source.clone()
            }

            fn description(&self) -> String {
                "Test tool".to_string()
            }

            fn ui_text(&self, _input: &serde_json::Value) -> String {
                "Test tool".to_string()
            }

            // Never invoked by this test; name resolution only reads metadata,
            // so the run task simply fails.
            fn run(
                self: Arc<Self>,
                _input: serde_json::Value,
                _request: Arc<LanguageModelRequest>,
                _project: Entity<Project>,
                _action_log: Entity<ActionLog>,
                _model: Arc<dyn LanguageModel>,
                _window: Option<AnyWindowHandle>,
                _cx: &mut App,
            ) -> assistant_tool::ToolResult {
                assistant_tool::ToolResult {
                    output: Task::ready(Err(anyhow::anyhow!("No content"))),
                    card: None,
                }
            }
        }
    }
3843
    // Drives `thread` through a summarization attempt that fails — the fake
    // model ends the summary stream without producing any text — leaving the
    // thread's summary in the `Error` state. Shared setup for the
    // summary-error tests above.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // The assistant turn completed, so summarization has kicked off.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // End the summary stream with no streamed text, which the thread
        // treats as a failed summarization.
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and the displayed summary falls back to the default.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3878
    // Streams a canned assistant reply through the fake model and settles the
    // test executor before and after, so the thread observes a completed turn.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3885
    // Registers the global settings store and every subsystem's settings /
    // globals needed by the thread tests in this module.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store must be installed first: the registrations
            // below record their settings against it.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3902
3903 // Helper to create a test project with test files
3904 async fn create_test_project(
3905 cx: &mut TestAppContext,
3906 files: serde_json::Value,
3907 ) -> Entity<Project> {
3908 let fs = FakeFs::new(cx.executor());
3909 fs.insert_tree(path!("/test"), files).await;
3910 Project::test(fs, [path!("/test").as_ref()], cx).await
3911 }
3912
3913 async fn setup_test_environment(
3914 cx: &mut TestAppContext,
3915 project: Entity<Project>,
3916 ) -> (
3917 Entity<Workspace>,
3918 Entity<ThreadStore>,
3919 Entity<Thread>,
3920 Entity<ContextStore>,
3921 Arc<dyn LanguageModel>,
3922 ) {
3923 let (workspace, cx) =
3924 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3925
3926 let thread_store = cx
3927 .update(|_, cx| {
3928 ThreadStore::load(
3929 project.clone(),
3930 cx.new(|_| ToolWorkingSet::default()),
3931 None,
3932 Arc::new(PromptBuilder::new(None).unwrap()),
3933 cx,
3934 )
3935 })
3936 .await
3937 .unwrap();
3938
3939 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3940 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3941
3942 let provider = Arc::new(FakeLanguageModelProvider);
3943 let model = provider.test_model();
3944 let model: Arc<dyn LanguageModel> = Arc::new(model);
3945
3946 cx.update(|_, cx| {
3947 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3948 registry.set_default_model(
3949 Some(ConfiguredModel {
3950 provider: provider.clone(),
3951 model: model.clone(),
3952 }),
3953 cx,
3954 );
3955 registry.set_thread_summary_model(
3956 Some(ConfiguredModel {
3957 provider,
3958 model: model.clone(),
3959 }),
3960 cx,
3961 );
3962 })
3963 });
3964
3965 (workspace, thread_store, thread, context_store, model)
3966 }
3967
3968 async fn add_file_to_context(
3969 project: &Entity<Project>,
3970 context_store: &Entity<ContextStore>,
3971 path: &str,
3972 cx: &mut TestAppContext,
3973 ) -> Result<Entity<language::Buffer>> {
3974 let buffer_path = project
3975 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3976 .unwrap();
3977
3978 let buffer = project
3979 .update(cx, |project, cx| {
3980 project.open_buffer(buffer_path.clone(), cx)
3981 })
3982 .await
3983 .unwrap();
3984
3985 context_store.update(cx, |context_store, cx| {
3986 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3987 });
3988
3989 Ok(buffer)
3990 }
3991}