1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
8use anyhow::{Result, anyhow};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{HashMap, HashSet};
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
42
43use crate::ThreadStore;
44use crate::agent_profile::AgentProfile;
45use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
46use crate::thread_store::{
47 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
48 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
49};
50use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
51
52#[derive(
53 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
54)]
55pub struct ThreadId(Arc<str>);
56
57impl ThreadId {
58 pub fn new() -> Self {
59 Self(Uuid::new_v4().to_string().into())
60 }
61}
62
63impl std::fmt::Display for ThreadId {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 write!(f, "{}", self.0)
66 }
67}
68
69impl From<&str> for ThreadId {
70 fn from(value: &str) -> Self {
71 Self(value.into())
72 }
73}
74
75/// The ID of the user prompt that initiated a request.
76///
77/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
78#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
79pub struct PromptId(Arc<str>);
80
81impl PromptId {
82 pub fn new() -> Self {
83 Self(Uuid::new_v4().to_string().into())
84 }
85}
86
87impl std::fmt::Display for PromptId {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 write!(f, "{}", self.0)
90 }
91}
92
93#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
94pub struct MessageId(pub(crate) usize);
95
96impl MessageId {
97 fn post_inc(&mut self) -> Self {
98 Self(post_inc(&mut self.0))
99 }
100}
101
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message text.
    pub range: Range<usize>,
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
110
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (text / thinking / redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context that was attached to this message when it was sent.
    pub loaded_context: LoadedContext,
    /// Creases for re-rendering context mentions in an editor.
    pub creases: Vec<MessageCrease>,
    /// Hidden messages (e.g. auto-generated "continue" prompts) are not shown in the UI.
    pub is_hidden: bool,
}
121
122impl Message {
123 /// Returns whether the message contains any meaningful text that should be displayed
124 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
125 pub fn should_display_content(&self) -> bool {
126 self.segments.iter().all(|segment| segment.should_display())
127 }
128
129 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
130 if let Some(MessageSegment::Thinking {
131 text: segment,
132 signature: current_signature,
133 }) = self.segments.last_mut()
134 {
135 if let Some(signature) = signature {
136 *current_signature = Some(signature);
137 }
138 segment.push_str(text);
139 } else {
140 self.segments.push(MessageSegment::Thinking {
141 text: text.to_string(),
142 signature,
143 });
144 }
145 }
146
147 pub fn push_redacted_thinking(&mut self, data: String) {
148 self.segments.push(MessageSegment::RedactedThinking(data));
149 }
150
151 pub fn push_text(&mut self, text: &str) {
152 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
153 segment.push_str(text);
154 } else {
155 self.segments.push(MessageSegment::Text(text.to_string()));
156 }
157 }
158
159 pub fn to_string(&self) -> String {
160 let mut result = String::new();
161
162 if !self.loaded_context.text.is_empty() {
163 result.push_str(&self.loaded_context.text);
164 }
165
166 for segment in &self.segments {
167 match segment {
168 MessageSegment::Text(text) => result.push_str(text),
169 MessageSegment::Thinking { text, .. } => {
170 result.push_str("<think>\n");
171 result.push_str(text);
172 result.push_str("\n</think>");
173 }
174 MessageSegment::RedactedThinking(_) => {}
175 }
176 }
177
178 result
179 }
180}
181
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain message text.
    Text(String),
    /// Visible chain-of-thought text with an optional provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque provider-encrypted thinking data; never rendered.
    RedactedThinking(String),
}

impl MessageSegment {
    /// Whether this segment has user-visible content worth rendering.
    ///
    /// Text and thinking segments are displayable only when they are
    /// non-empty; redacted thinking is never shown.
    pub fn should_display(&self) -> bool {
        match self {
            // Fixed inverted condition: an empty segment must NOT be
            // considered displayable.
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
201
/// Snapshot of the project's worktrees captured when a thread starts,
/// persisted alongside the serialized thread.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree state captured in a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Git state, if any was captured for this worktree.
    pub git_state: Option<GitState>,
}

/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
222
/// Pairs a message with the git checkpoint taken when it was sent, so the
/// workspace can later be restored to that point.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Thumbs-up / thumbs-down feedback on a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
234
235pub enum LastRestoreCheckpoint {
236 Pending {
237 message_id: MessageId,
238 },
239 Error {
240 message_id: MessageId,
241 error: String,
242 },
243}
244
245impl LastRestoreCheckpoint {
246 pub fn message_id(&self) -> MessageId {
247 match self {
248 LastRestoreCheckpoint::Pending { message_id } => *message_id,
249 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
250 }
251 }
252}
253
254#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
255pub enum DetailedSummaryState {
256 #[default]
257 NotGenerated,
258 Generating {
259 message_id: MessageId,
260 },
261 Generated {
262 text: SharedString,
263 message_id: MessageId,
264 },
265}
266
267impl DetailedSummaryState {
268 fn text(&self) -> Option<SharedString> {
269 if let Self::Generated { text, .. } = self {
270 Some(text.clone())
271 } else {
272 None
273 }
274 }
275}
276
/// A thread's token consumption measured against a model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size; 0 when unknown (no model selected).
    pub max: usize,
}
282
283impl TotalTokenUsage {
284 pub fn ratio(&self) -> TokenUsageRatio {
285 #[cfg(debug_assertions)]
286 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
287 .unwrap_or("0.8".to_string())
288 .parse()
289 .unwrap();
290 #[cfg(not(debug_assertions))]
291 let warning_threshold: f32 = 0.8;
292
293 // When the maximum is unknown because there is no selected model,
294 // avoid showing the token limit warning.
295 if self.max == 0 {
296 TokenUsageRatio::Normal
297 } else if self.total >= self.max {
298 TokenUsageRatio::Exceeded
299 } else if self.total as f32 / self.max as f32 >= warning_threshold {
300 TokenUsageRatio::Warning
301 } else {
302 TokenUsageRatio::Normal
303 }
304 }
305
306 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
307 TotalTokenUsage {
308 total: self.total + tokens,
309 max: self.max,
310 }
311 }
312}
313
/// Bucketed severity of context-window consumption, computed by
/// [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}

/// Queueing status of a pending completion request.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    /// Waiting behind other requests at the given position.
    Queued { position: usize },
    Started,
}
328
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    // In-flight tasks generating the short and detailed summaries.
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    // Watch channel broadcasting detailed-summary progress to observers.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Kept sorted by id; `message()` relies on this for binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints keyed by the message that created them.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint recorded at submission time; finalized once generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history plus the running total across requests.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recent streamed chunk; used for staleness checks.
    last_received_chunk_at: Option<Instant>,
    // Optional hook observing every request and its raw completion events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
370
371#[derive(Clone, Debug, PartialEq, Eq)]
372pub enum ThreadSummary {
373 Pending,
374 Generating,
375 Ready(SharedString),
376 Error,
377}
378
379impl ThreadSummary {
380 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
381
382 pub fn or_default(&self) -> SharedString {
383 self.unwrap_or(Self::DEFAULT)
384 }
385
386 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
387 self.ready().unwrap_or_else(|| message.into())
388 }
389
390 pub fn ready(&self) -> Option<SharedString> {
391 match self {
392 ThreadSummary::Ready(summary) => Some(summary.clone()),
393 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
394 }
395 }
396}
397
/// Recorded when a completion request exceeds the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
405
406impl Thread {
    /// Creates a new, empty thread.
    ///
    /// The configured model and profile are seeded from global settings, and
    /// an initial project snapshot is captured in the background.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the snapshot asynchronously so thread creation
                // doesn't block on worktree/git inspection.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
462
    /// Reconstructs a thread from its serialized form.
    ///
    /// `window` is `None` in headless mode. Deserialized messages carry no
    /// live context handles (each crease's `context` is `None`), and the
    /// model, profile, and completion mode fall back to global defaults when
    /// absent from the serialized data.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the last persisted message.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model the thread was saved with; fall back to the
        // registry default when it is unavailable.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
587
    /// Installs a hook invoked with each completion request and the raw
    /// events received for it (used to capture/inspect model traffic).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }
599
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    /// Switches the agent profile, emitting [`ThreadEvent::ProfileChanged`]
    /// only when the id actually differs from the current one.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }

    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps `updated_at` to now; called when messages are inserted, edited,
    /// or deleted.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }
626
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, falling back to (and caching) the
    /// registry's default model when none has been set yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
650
    /// Manually sets the summary text.
    ///
    /// Ignored while a summary is pending or generating. An empty string is
    /// replaced with the default placeholder, and `SummaryChanged` is emitted
    /// only when the text actually changes.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }

    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
677
    /// Looks up a message by id.
    ///
    /// Uses binary search, relying on `messages` being sorted by id — which
    /// holds because ids come from a monotonically increasing counter.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// True while a completion is in flight or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
694
    /// Indicates whether streaming of language model events is stale.
    ///
    /// Returns `None` when no chunk has been received yet.
    /// NOTE(review): this only looks at `last_received_chunk_at`, which is
    /// not cleared here when generation ends — callers should consult it
    /// only while `is_generating()` is true; confirm against the completion
    /// handlers elsewhere in this file.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a new chunk before the stream counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived (for staleness tracking).
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
713
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The pending tool use with the given id, if any.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that require explicit user confirmation to run.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
739
    /// Restores the project to `checkpoint`'s git state, truncating the
    /// thread back to the checkpoint's message on success.
    ///
    /// Progress and failure are surfaced through `last_restore_checkpoint`
    /// and [`ThreadEvent::CheckpointChanged`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way the previously pending checkpoint is moot now.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
775
776 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
777 let pending_checkpoint = if self.is_generating() {
778 return;
779 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
780 checkpoint
781 } else {
782 return;
783 };
784
785 self.finalize_checkpoint(pending_checkpoint, cx);
786 }
787
    /// Compares the pending checkpoint against the repository's current
    /// state and records it only when the two differ (i.e. the turn actually
    /// changed something). If the comparison itself fails, the checkpoint is
    /// kept to be safe.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }

    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
833
    /// Removes `message_id` and every message after it, dropping any
    /// checkpoints associated with the removed messages. No-op when the
    /// message doesn't exist.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }

    /// The contexts attached to the given message (empty if it doesn't exist).
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
855
856 pub fn is_turn_end(&self, ix: usize) -> bool {
857 if self.messages.is_empty() {
858 return false;
859 }
860
861 if !self.is_generating() && ix == self.messages.len() - 1 {
862 return true;
863 }
864
865 let Some(message) = self.messages.get(ix) else {
866 return false;
867 };
868
869 if message.role != Role::Assistant {
870 return false;
871 }
872
873 self.messages
874 .get(ix + 1)
875 .and_then(|message| {
876 self.message(message.id)
877 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
878 })
879 .unwrap_or(false)
880 }
881
    /// Usage reported by the most recent completion request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the tool-use limit has been reached for the current prompt.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            // Errored tool uses count as finished, so exclude them.
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }
908
    /// Tool uses issued by the given assistant message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Results of the tools run for the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Text output of a finished tool use; `None` for image results.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// The UI card associated with a tool result, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
937
    /// Return tools that are both enabled and supported by the model
    ///
    /// Returns an empty list for models without tool support. Enabled tools
    /// whose input schema cannot be produced for the model's input format
    /// are silently skipped; conflicting tool names are resolved first.
    pub fn available_tools(
        &self,
        cx: &App,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<LanguageModelRequestTool> {
        if model.supports_tools() {
            resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
                .into_iter()
                .filter_map(|(name, tool)| {
                    // Skip tools that cannot be supported
                    let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
                    Some(LanguageModelRequestTool {
                        name,
                        description: tool.description(),
                        input_schema,
                    })
                })
                .collect()
        } else {
            Vec::default()
        }
    }
961
    /// Inserts a user message, marking any buffers referenced by the loaded
    /// context as read in the action log and recording `git_checkpoint` as
    /// this turn's pending checkpoint.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }

    /// Inserts a hidden user message asking the model to continue, without
    /// any attached context; clears any pending checkpoint.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }

    /// Inserts a visible assistant message with no attached context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1027
    /// Core insertion path: assigns the next message id, appends the message
    /// (keeping `messages` sorted by id), bumps `updated_at`, and emits
    /// [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1050
    /// Replaces an existing message's role, segments, and creases (and
    /// optionally its loaded context), recording a new checkpoint for it
    /// when one is given. Returns `false` when the message doesn't exist.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1083
1084 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1085 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1086 return false;
1087 };
1088 self.messages.remove(index);
1089 self.touch_updated_at();
1090 cx.emit(ThreadEvent::MessageDeleted(id));
1091 true
1092 }
1093
1094 /// Returns the representation of this [`Thread`] in a textual form.
1095 ///
1096 /// This is the representation we use when attaching a thread as context to another thread.
1097 pub fn text(&self) -> String {
1098 let mut text = String::new();
1099
1100 for message in &self.messages {
1101 text.push_str(match message.role {
1102 language_model::Role::User => "User:",
1103 language_model::Role::Assistant => "Agent:",
1104 language_model::Role::System => "System:",
1105 });
1106 text.push('\n');
1107
1108 for segment in &message.segments {
1109 match segment {
1110 MessageSegment::Text(content) => text.push_str(content),
1111 MessageSegment::Thinking { text: content, .. } => {
1112 text.push_str(&format!("<think>{}</think>", content))
1113 }
1114 MessageSegment::RedactedThinking(_) => {}
1115 }
1116 }
1117 text.push('\n');
1118 }
1119
1120 text
1121 }
1122
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Returns a task because the initial project snapshot is produced
    /// asynchronously; the serialization itself runs once it resolves. The
    /// resulting [`SerializedThread`] captures messages (with their segments,
    /// tool uses, tool results, loaded context text, and creases), token
    /// usage, summary state, the configured model, and the active profile.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        // Each in-memory segment variant maps 1:1 onto its
                        // serialized counterpart.
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text is persisted, not the
                        // full loaded-context structure.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1208
    /// Returns how many more turns may be sent to the model before
    /// [`Self::send_to_model`] starts refusing to run.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1212
    /// Sets the number of turns that may still be sent to the model.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1216
1217 pub fn send_to_model(
1218 &mut self,
1219 model: Arc<dyn LanguageModel>,
1220 intent: CompletionIntent,
1221 window: Option<AnyWindowHandle>,
1222 cx: &mut Context<Self>,
1223 ) {
1224 if self.remaining_turns == 0 {
1225 return;
1226 }
1227
1228 self.remaining_turns -= 1;
1229
1230 let request = self.to_completion_request(model.clone(), intent, cx);
1231
1232 self.stream_completion(request, model, window, cx);
1233 }
1234
1235 pub fn used_tools_since_last_user_message(&self) -> bool {
1236 for message in self.messages.iter().rev() {
1237 if self.tool_use.message_has_tool_results(message.id) {
1238 return true;
1239 } else if message.role == Role::User {
1240 return false;
1241 }
1242 }
1243
1244 false
1245 }
1246
    /// Builds the [`LanguageModelRequest`] representing this thread's current
    /// state: system prompt, conversation messages, interleaved tool results,
    /// stale-file notices, the available tools, and the completion mode.
    ///
    /// Also selects the most recent fully-resolved message as the prompt-cache
    /// anchor (see the Anthropic prompt-caching link below).
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt requires the project context to be ready; if it
        // isn't (or prompt generation fails) we surface an error but still
        // build the rest of the request.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Tracks the index of the last request message eligible for prompt
        // caching (i.e. with no still-pending tool uses before it).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses are attached to this message; their results
            // go into a synthetic follow-up user message. Pending tool uses
            // are skipped and disqualify this message from being cached.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1406
1407 fn to_summarize_request(
1408 &self,
1409 model: &Arc<dyn LanguageModel>,
1410 intent: CompletionIntent,
1411 added_user_message: String,
1412 cx: &App,
1413 ) -> LanguageModelRequest {
1414 let mut request = LanguageModelRequest {
1415 thread_id: None,
1416 prompt_id: None,
1417 intent: Some(intent),
1418 mode: None,
1419 messages: vec![],
1420 tools: Vec::new(),
1421 tool_choice: None,
1422 stop: Vec::new(),
1423 temperature: AgentSettings::temperature_for_model(model, cx),
1424 };
1425
1426 for message in &self.messages {
1427 let mut request_message = LanguageModelRequestMessage {
1428 role: message.role,
1429 content: Vec::new(),
1430 cache: false,
1431 };
1432
1433 for segment in &message.segments {
1434 match segment {
1435 MessageSegment::Text(text) => request_message
1436 .content
1437 .push(MessageContent::Text(text.clone())),
1438 MessageSegment::Thinking { .. } => {}
1439 MessageSegment::RedactedThinking(_) => {}
1440 }
1441 }
1442
1443 if request_message.content.is_empty() {
1444 continue;
1445 }
1446
1447 request.messages.push(request_message);
1448 }
1449
1450 request.messages.push(LanguageModelRequestMessage {
1451 role: Role::User,
1452 content: vec![MessageContent::Text(added_user_message)],
1453 cache: false,
1454 });
1455
1456 request
1457 }
1458
1459 fn attached_tracked_files_state(
1460 &self,
1461 messages: &mut Vec<LanguageModelRequestMessage>,
1462 cx: &App,
1463 ) {
1464 const STALE_FILES_HEADER: &str = include_str!("./prompts/stale_files_prompt_header.txt");
1465
1466 let mut stale_message = String::new();
1467
1468 let action_log = self.action_log.read(cx);
1469
1470 for stale_file in action_log.stale_buffers(cx) {
1471 let Some(file) = stale_file.read(cx).file() else {
1472 continue;
1473 };
1474
1475 if stale_message.is_empty() {
1476 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER.trim()).ok();
1477 }
1478
1479 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1480 }
1481
1482 let mut content = Vec::with_capacity(2);
1483
1484 if !stale_message.is_empty() {
1485 content.push(stale_message.into());
1486 }
1487
1488 if !content.is_empty() {
1489 let context_message = LanguageModelRequestMessage {
1490 role: Role::User,
1491 content,
1492 cache: false,
1493 };
1494
1495 messages.push(context_message);
1496 }
1497 }
1498
    /// Streams `request` to `model` and incrementally applies the resulting
    /// events to this thread: appending text/thinking chunks, registering tool
    /// uses, tracking token usage and queue status, and handling errors and
    /// refusals when the stream ends.
    ///
    /// The work runs in a spawned task that is kept alive by the
    /// [`PendingCompletion`] pushed at the end of this method; dropping that
    /// entry cancels the stream.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // When a request callback is installed, record the request and every
        // response event so the callback can replay them afterwards.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot usage before the stream so the per-request delta can be
            // reported to telemetry at the end.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // The assistant message created for this response, if any yet.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            // Bad tool-input JSON is recoverable: report it as
                            // a failed tool call and keep consuming the stream.
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                            Err(err @ LanguageModelCompletionError::RateLimit(..)) => {
                                return Err(err.into());
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates are cumulative per request, so
                                // add only the delta since the previous update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::RedactedThinking {
                                data
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_redacted_thinking(data);
                                    } else {
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::RedactedThinking(data)],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses need an assistant message to hang
                                // off of; create an empty one if none exists.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Only partially-streamed tool inputs are
                                // re-emitted as `StreamedToolUse` below.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                            cx.emit(ThreadEvent::ToolUseLimitReached);
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Yield so other tasks get a chance to run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                            StopReason::Refusal => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });

                                // Remove the turn that was refused.
                                //
                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                {
                                    let mut messages_to_remove = Vec::new();

                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
                                        messages_to_remove.push(message.id);

                                        if message.role == Role::User {
                                            if ix == 0 {
                                                break;
                                            }

                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                if prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }
                                    }

                                    for message_id in messages_to_remove {
                                        thread.delete_message(message_id, cx);
                                    }
                                }

                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Language model refusal".into(),
                                    message: "Model refused to generate content for safety reasons.".into(),
                                }));
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types onto their dedicated UI
                            // treatments; everything else becomes a generic
                            // error message built from the error chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }

                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        // Keep the task alive (and cancellable) via the pending-completions list.
        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1902
1903 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1904 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1905 println!("No thread summary model");
1906 return;
1907 };
1908
1909 if !model.provider.is_authenticated(cx) {
1910 return;
1911 }
1912
1913 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
1914
1915 let request = self.to_summarize_request(
1916 &model.model,
1917 CompletionIntent::ThreadSummarization,
1918 added_user_message.into(),
1919 cx,
1920 );
1921
1922 self.summary = ThreadSummary::Generating;
1923
1924 self.pending_summary = cx.spawn(async move |this, cx| {
1925 let result = async {
1926 let mut messages = model.model.stream_completion(request, &cx).await?;
1927
1928 let mut new_summary = String::new();
1929 while let Some(event) = messages.next().await {
1930 let Ok(event) = event else {
1931 continue;
1932 };
1933 let text = match event {
1934 LanguageModelCompletionEvent::Text(text) => text,
1935 LanguageModelCompletionEvent::StatusUpdate(
1936 CompletionRequestStatus::UsageUpdated { amount, limit },
1937 ) => {
1938 this.update(cx, |thread, _cx| {
1939 thread.last_usage = Some(RequestUsage {
1940 limit,
1941 amount: amount as i32,
1942 });
1943 })?;
1944 continue;
1945 }
1946 _ => continue,
1947 };
1948
1949 let mut lines = text.lines();
1950 new_summary.extend(lines.next());
1951
1952 // Stop if the LLM generated multiple lines.
1953 if lines.next().is_some() {
1954 break;
1955 }
1956 }
1957
1958 anyhow::Ok(new_summary)
1959 }
1960 .await;
1961
1962 this.update(cx, |this, cx| {
1963 match result {
1964 Ok(new_summary) => {
1965 if new_summary.is_empty() {
1966 this.summary = ThreadSummary::Error;
1967 } else {
1968 this.summary = ThreadSummary::Ready(new_summary.into());
1969 }
1970 }
1971 Err(err) => {
1972 this.summary = ThreadSummary::Error;
1973 log::error!("Failed to generate thread summary: {}", err);
1974 }
1975 }
1976 cx.emit(ThreadEvent::SummaryGenerated);
1977 })
1978 .log_err()?;
1979
1980 Some(())
1981 });
1982 }
1983
    /// Kicks off generation of a detailed, multi-line summary of the thread
    /// unless one is already generated (or generating) for the current last
    /// message.
    ///
    /// Progress is published through the `detailed_summary_tx`/`rx` watch
    /// channel, and the finished thread is saved via `thread_store` so the
    /// summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Reset the state so a later call can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2073
    /// Waits on the detailed-summary watch channel until generation settles,
    /// then returns the detailed summary if one was produced, or the thread's
    /// plain text rendering if generation never happened.
    ///
    /// Returns `None` if the thread entity is dropped or the channel closes.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still in flight; keep waiting for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
2091
2092 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2093 self.detailed_summary_rx
2094 .borrow()
2095 .text()
2096 .unwrap_or_else(|| self.text().into())
2097 }
2098
    /// Reports whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2105
    /// Dispatches all idle pending tool uses: tools requiring confirmation
    /// (per settings) are queued for user approval, runnable tools are
    /// executed immediately, and unknown tool names are handled as
    /// hallucinations.
    ///
    /// Returns the pending tool uses that were considered.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // One request snapshot is shared by all tool invocations in this batch.
        let request =
            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                // Confirmation can be bypassed globally via settings.
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AgentSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool we don't have.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
2160
2161 pub fn handle_hallucinated_tool_use(
2162 &mut self,
2163 tool_use_id: LanguageModelToolUseId,
2164 hallucinated_tool_name: Arc<str>,
2165 window: Option<AnyWindowHandle>,
2166 cx: &mut Context<Thread>,
2167 ) {
2168 let available_tools = self.profile.enabled_tools(cx);
2169
2170 let tool_list = available_tools
2171 .iter()
2172 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2173 .collect::<Vec<_>>()
2174 .join("\n");
2175
2176 let error_message = format!(
2177 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2178 hallucinated_tool_name, tool_list
2179 );
2180
2181 let pending_tool_use = self.tool_use.insert_tool_output(
2182 tool_use_id.clone(),
2183 hallucinated_tool_name,
2184 Err(anyhow!("Missing tool call: {error_message}")),
2185 self.configured_model.as_ref(),
2186 );
2187
2188 cx.emit(ThreadEvent::MissingToolUse {
2189 tool_use_id: tool_use_id.clone(),
2190 ui_text: error_message.into(),
2191 });
2192
2193 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2194 }
2195
2196 pub fn receive_invalid_tool_json(
2197 &mut self,
2198 tool_use_id: LanguageModelToolUseId,
2199 tool_name: Arc<str>,
2200 invalid_json: Arc<str>,
2201 error: String,
2202 window: Option<AnyWindowHandle>,
2203 cx: &mut Context<Thread>,
2204 ) {
2205 log::error!("The model returned invalid input JSON: {invalid_json}");
2206
2207 let pending_tool_use = self.tool_use.insert_tool_output(
2208 tool_use_id.clone(),
2209 tool_name,
2210 Err(anyhow!("Error parsing input JSON: {error}")),
2211 self.configured_model.as_ref(),
2212 );
2213 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2214 pending_tool_use.ui_text.clone()
2215 } else {
2216 log::error!(
2217 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2218 );
2219 format!("Unknown tool {}", tool_use_id).into()
2220 };
2221
2222 cx.emit(ThreadEvent::InvalidToolInput {
2223 tool_use_id: tool_use_id.clone(),
2224 ui_text,
2225 invalid_input_json: invalid_json,
2226 });
2227
2228 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2229 }
2230
2231 pub fn run_tool(
2232 &mut self,
2233 tool_use_id: LanguageModelToolUseId,
2234 ui_text: impl Into<SharedString>,
2235 input: serde_json::Value,
2236 request: Arc<LanguageModelRequest>,
2237 tool: Arc<dyn Tool>,
2238 model: Arc<dyn LanguageModel>,
2239 window: Option<AnyWindowHandle>,
2240 cx: &mut Context<Thread>,
2241 ) {
2242 let task =
2243 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2244 self.tool_use
2245 .run_pending_tool(tool_use_id, ui_text.into(), task);
2246 }
2247
    /// Starts running `tool` and returns a task that records its output when
    /// the tool completes.
    ///
    /// The tool's result card (if any) is registered immediately so the UI
    /// can render it while the tool is still running; the output itself is
    /// inserted once the returned future resolves, after which
    /// `tool_finished` may resume the conversation with the model.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool was running;
                // in that case there's nothing left to record, so ignore the error.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2294
    /// Records that a tool use finished (successfully, with an error, or via
    /// cancellation) and emits [`ThreadEvent::ToolFinished`].
    ///
    /// When this was the last outstanding tool and the run wasn't canceled,
    /// the accumulated tool results are sent back to the configured model so
    /// the conversation can continue.
    fn tool_finished(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        pending_tool_use: Option<PendingToolUse>,
        canceled: bool,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        if self.all_tools_finished() {
            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
                if !canceled {
                    self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
                }
                self.auto_capture_telemetry(cx);
            }
        }

        // The event fires even when more tools are still pending, so listeners
        // can update per-tool UI state.
        cx.emit(ThreadEvent::ToolFinished {
            tool_use_id,
            pending_tool_use,
        });
    }
2317
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        let mut canceled = self.pending_completions.pop().is_some();

        // Any tools still running are canceled too; each counts as a cancellation.
        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);

            // When canceled, we always want to insert the checkpoint.
            // (We skip over finalize_pending_checkpoint, because it
            // would conclude we didn't have anything to insert here.)
            if let Some(checkpoint) = self.pending_checkpoint.take() {
                self.insert_checkpoint(checkpoint, cx);
            }
        } else {
            self.finalize_pending_checkpoint(cx);
        }

        canceled
    }
2354
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }

    /// Returns the thread-level feedback, if any was recorded.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }

    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2370
    /// Records feedback for a single message and reports it via telemetry.
    ///
    /// No-ops (returning an immediately-ready task) when the same feedback
    /// was already recorded for this message. The returned task completes
    /// once the telemetry event has been flushed.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Snapshot everything we need before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the telemetry event entirely.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2427
    /// Records thread-level feedback, attributing it to the most recent
    /// assistant message when one exists.
    ///
    /// Falls back to reporting an un-attributed "thread rated" telemetry
    /// event for threads with no assistant messages yet.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Degrade to a null payload on serialization failure.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2473
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Kick off a snapshot of every visible worktree up front so they can
        // be awaited together in the spawned task.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect the paths of dirty (unsaved) buffers. Failure to read
            // app state (e.g. during shutdown) is ignored, leaving the list empty.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2511
    /// Captures a snapshot of one worktree's path and (best-effort) git state
    /// for inclusion in a `ProjectSnapshot`.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app state is gone, return an empty snapshot rather than failing.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository containing this worktree's root, if any.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Non-local repositories can't be inspected here;
                            // report only the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option/Result/future layers; any failure
            // along the way yields `None` git state.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2589
2590 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2591 let mut markdown = Vec::new();
2592
2593 let summary = self.summary().or_default();
2594 writeln!(markdown, "# {summary}\n")?;
2595
2596 for message in self.messages() {
2597 writeln!(
2598 markdown,
2599 "## {role}\n",
2600 role = match message.role {
2601 Role::User => "User",
2602 Role::Assistant => "Agent",
2603 Role::System => "System",
2604 }
2605 )?;
2606
2607 if !message.loaded_context.text.is_empty() {
2608 writeln!(markdown, "{}", message.loaded_context.text)?;
2609 }
2610
2611 if !message.loaded_context.images.is_empty() {
2612 writeln!(
2613 markdown,
2614 "\n{} images attached as context.\n",
2615 message.loaded_context.images.len()
2616 )?;
2617 }
2618
2619 for segment in &message.segments {
2620 match segment {
2621 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2622 MessageSegment::Thinking { text, .. } => {
2623 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2624 }
2625 MessageSegment::RedactedThinking(_) => {}
2626 }
2627 }
2628
2629 for tool_use in self.tool_uses_for_message(message.id, cx) {
2630 writeln!(
2631 markdown,
2632 "**Use Tool: {} ({})**",
2633 tool_use.name, tool_use.id
2634 )?;
2635 writeln!(markdown, "```json")?;
2636 writeln!(
2637 markdown,
2638 "{}",
2639 serde_json::to_string_pretty(&tool_use.input)?
2640 )?;
2641 writeln!(markdown, "```")?;
2642 }
2643
2644 for tool_result in self.tool_results_for_message(message.id) {
2645 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2646 if tool_result.is_error {
2647 write!(markdown, " (Error)")?;
2648 }
2649
2650 writeln!(markdown, "**\n")?;
2651 match &tool_result.content {
2652 LanguageModelToolResultContent::Text(text) => {
2653 writeln!(markdown, "{text}")?;
2654 }
2655 LanguageModelToolResultContent::Image(image) => {
2656 writeln!(markdown, "", image.source)?;
2657 }
2658 }
2659
2660 if let Some(output) = tool_result.output.as_ref() {
2661 writeln!(
2662 markdown,
2663 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2664 serde_json::to_string_pretty(output)?
2665 )?;
2666 }
2667 }
2668 }
2669
2670 Ok(String::from_utf8_lossy(&markdown).to_string())
2671 }
2672
    /// Marks the agent's edits within `buffer_range` as accepted by the user.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }

    /// Marks every outstanding agent edit as accepted by the user.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }

    /// Reverts the agent's edits within the given ranges of `buffer`.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }

    /// The log of buffer edits performed by this thread's tools.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2707
2708 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2709 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2710 return;
2711 }
2712
2713 let now = Instant::now();
2714 if let Some(last) = self.last_auto_capture_at {
2715 if now.duration_since(last).as_secs() < 10 {
2716 return;
2717 }
2718 }
2719
2720 self.last_auto_capture_at = Some(now);
2721
2722 let thread_id = self.id().clone();
2723 let github_login = self
2724 .project
2725 .read(cx)
2726 .user_store()
2727 .read(cx)
2728 .current_user()
2729 .map(|user| user.github_login.clone());
2730 let client = self.project.read(cx).client();
2731 let serialize_task = self.serialize(cx);
2732
2733 cx.background_executor()
2734 .spawn(async move {
2735 if let Ok(serialized_thread) = serialize_task.await {
2736 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2737 telemetry::event!(
2738 "Agent Thread Auto-Captured",
2739 thread_id = thread_id.to_string(),
2740 thread_data = thread_data,
2741 auto_capture_reason = "tracked_user",
2742 github_login = github_login
2743 );
2744
2745 client.telemetry().flush_events().await;
2746 }
2747 }
2748 })
2749 .detach();
2750 }
2751
    /// Total token usage accumulated across every completion in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2755
2756 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2757 let Some(model) = self.configured_model.as_ref() else {
2758 return TotalTokenUsage::default();
2759 };
2760
2761 let max = model.model.max_token_count();
2762
2763 let index = self
2764 .messages
2765 .iter()
2766 .position(|msg| msg.id == message_id)
2767 .unwrap_or(0);
2768
2769 if index == 0 {
2770 return TotalTokenUsage { total: 0, max };
2771 }
2772
2773 let token_usage = &self
2774 .request_token_usage
2775 .get(index - 1)
2776 .cloned()
2777 .unwrap_or_default();
2778
2779 TotalTokenUsage {
2780 total: token_usage.total_tokens() as usize,
2781 max,
2782 }
2783 }
2784
2785 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2786 let model = self.configured_model.as_ref()?;
2787
2788 let max = model.model.max_token_count();
2789
2790 if let Some(exceeded_error) = &self.exceeded_window_error {
2791 if model.model.id() == exceeded_error.model_id {
2792 return Some(TotalTokenUsage {
2793 total: exceeded_error.token_count,
2794 max,
2795 });
2796 }
2797 }
2798
2799 let total = self
2800 .token_usage_at_last_message()
2801 .unwrap_or_default()
2802 .total_tokens() as usize;
2803
2804 Some(TotalTokenUsage { total, max })
2805 }
2806
    /// Token usage recorded at (or nearest before) the last message.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            // Fewer usage entries than messages: fall back to the most recent one.
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }

    /// Overwrites the usage entry for the last message, growing the usage
    /// vector (padded with the previous value) if it lags behind `messages`.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2823
2824 pub fn deny_tool_use(
2825 &mut self,
2826 tool_use_id: LanguageModelToolUseId,
2827 tool_name: Arc<str>,
2828 window: Option<AnyWindowHandle>,
2829 cx: &mut Context<Self>,
2830 ) {
2831 let err = Err(anyhow::anyhow!(
2832 "Permission to run tool action denied by user"
2833 ));
2834
2835 self.tool_use.insert_tool_output(
2836 tool_use_id.clone(),
2837 tool_name,
2838 err,
2839 self.configured_model.as_ref(),
2840 );
2841 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2842 }
2843}
2844
/// User-facing errors surfaced while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's plan requires payment before more completions can run.
    #[error("Payment required")]
    PaymentRequired,
    /// The user hit their plan's model-request quota.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a header and message to display in the UI.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2857
/// Events emitted by a `Thread` as completions stream, tools run, and
/// messages change.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// A chunk of assistant response text arrived for the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant "thinking" text arrived for the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed a tool call.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// The model referenced a tool that doesn't exist or isn't enabled.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully or not); see `Thread::tool_finished`.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
}

impl EventEmitter<ThreadEvent> for Thread {}
2904
/// A completion request that is currently streaming.
struct PendingCompletion {
    // Monotonically increasing id used to identify this completion.
    id: usize,
    queue_state: QueueState,
    // Dropping this task cancels the in-flight completion.
    _task: Task<()>,
}
2910
2911/// Resolves tool name conflicts by ensuring all tool names are unique.
2912///
2913/// When multiple tools have the same name, this function applies the following rules:
2914/// 1. Native tools always keep their original name
2915/// 2. Context server tools get prefixed with their server ID and an underscore
2916/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
2917/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
2918///
2919/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
2920fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
2921 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
2922 let mut tool_name = tool.name();
2923 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
2924 tool_name
2925 }
2926
2927 const MAX_TOOL_NAME_LENGTH: usize = 64;
2928
2929 let mut duplicated_tool_names = HashSet::default();
2930 let mut seen_tool_names = HashSet::default();
2931 for tool in tools {
2932 let tool_name = resolve_tool_name(tool);
2933 if seen_tool_names.contains(&tool_name) {
2934 debug_assert!(
2935 tool.source() != assistant_tool::ToolSource::Native,
2936 "There are two built-in tools with the same name: {}",
2937 tool_name
2938 );
2939 duplicated_tool_names.insert(tool_name);
2940 } else {
2941 seen_tool_names.insert(tool_name);
2942 }
2943 }
2944
2945 if duplicated_tool_names.is_empty() {
2946 return tools
2947 .into_iter()
2948 .map(|tool| (resolve_tool_name(tool), tool.clone()))
2949 .collect();
2950 }
2951
2952 tools
2953 .into_iter()
2954 .filter_map(|tool| {
2955 let mut tool_name = resolve_tool_name(tool);
2956 if !duplicated_tool_names.contains(&tool_name) {
2957 return Some((tool_name, tool.clone()));
2958 }
2959 match tool.source() {
2960 assistant_tool::ToolSource::Native => {
2961 // Built-in tools always keep their original name
2962 Some((tool_name, tool.clone()))
2963 }
2964 assistant_tool::ToolSource::ContextServer { id } => {
2965 // Context server tools are prefixed with the context server ID, and truncated if necessary
2966 tool_name.insert(0, '_');
2967 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
2968 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
2969 let mut id = id.to_string();
2970 id.truncate(len);
2971 tool_name.insert_str(0, &id);
2972 } else {
2973 tool_name.insert_str(0, &id);
2974 }
2975
2976 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
2977
2978 if seen_tool_names.contains(&tool_name) {
2979 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
2980 None
2981 } else {
2982 Some((tool_name, tool.clone()))
2983 }
2984 }
2985 }
2986 })
2987 .collect()
2988}
2989
2990#[cfg(test)]
2991mod tests {
2992 use super::*;
2993 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2994 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
2995 use assistant_tool::ToolRegistry;
2996 use editor::EditorSettings;
2997 use gpui::TestAppContext;
2998 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2999 use project::{FakeFs, Project};
3000 use prompt_store::PromptBuilder;
3001 use serde_json::json;
3002 use settings::{Settings, SettingsStore};
3003 use std::sync::Arc;
3004 use theme::ThemeSettings;
3005 use ui::IconName;
3006 use util::path;
3007 use workspace::Workspace;
3008
    // Verifies that file context attached to a user message is rendered into
    // the message's loaded context and included in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Two messages: the system prompt plus our user message.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3085
    // Verifies that each new user message only carries context that hasn't
    // already been attached to an earlier message, and that
    // `new_context_for_thread` respects the optional message-id cutoff.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a cutoff at message 2, contexts attached at or after message 2
        // (files 2-4) count as "new" again, while file1 stays excluded.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3241
    // Verifies that messages inserted without any attached context produce an
    // empty loaded-context string and clean completion-request contents.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3319
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open a buffer and register it as message context.
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        // Load the tracked context so it can be attached to the message.
        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert a user message with the buffer as context.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Before the buffer changes, no stale-file warning should be present
        // in the completion request.
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer to make the previously-read context stale.
        buffer.update(cx, |buffer, cx| {
            // Insert at offset 1 (just after the first character); any edit
            // is enough to mark the buffer as changed since last read.
            buffer.edit(
                [(1..1, "\n println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message, this time without context.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // The next request should now end with a stale-buffer notification.
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The notification is delivered as a trailing user-role message.
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the notification message.
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3407
3408 #[gpui::test]
3409 async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
3410 init_test_settings(cx);
3411
3412 let project = create_test_project(
3413 cx,
3414 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3415 )
3416 .await;
3417
3418 let (_workspace, thread_store, thread, _context_store, _model) =
3419 setup_test_environment(cx, project.clone()).await;
3420
3421 // Check that we are starting with the default profile
3422 let profile = cx.read(|cx| thread.read(cx).profile.clone());
3423 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3424 assert_eq!(
3425 profile,
3426 AgentProfile::new(AgentProfileId::default(), tool_set)
3427 );
3428 }
3429
3430 #[gpui::test]
3431 async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
3432 init_test_settings(cx);
3433
3434 let project = create_test_project(
3435 cx,
3436 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
3437 )
3438 .await;
3439
3440 let (_workspace, thread_store, thread, _context_store, _model) =
3441 setup_test_environment(cx, project.clone()).await;
3442
3443 // Profile gets serialized with default values
3444 let serialized = thread
3445 .update(cx, |thread, cx| thread.serialize(cx))
3446 .await
3447 .unwrap();
3448
3449 assert_eq!(serialized.profile, Some(AgentProfileId::default()));
3450
3451 let deserialized = cx.update(|cx| {
3452 thread.update(cx, |thread, cx| {
3453 Thread::deserialize(
3454 thread.id.clone(),
3455 serialized,
3456 thread.project.clone(),
3457 thread.tools.clone(),
3458 thread.prompt_builder.clone(),
3459 thread.project_context.clone(),
3460 None,
3461 cx,
3462 )
3463 })
3464 });
3465 let tool_set = cx.read(|cx| thread_store.read(cx).tools());
3466
3467 assert_eq!(
3468 deserialized.profile,
3469 AgentProfile::new(AgentProfileId::default(), tool_set)
3470 );
3471 }
3472
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Case 1: parameter entry matches on both provider and model id, so
        // the configured temperature should be applied to the request.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Case 2: entry specifies only the model id (provider is None); it
        // should still match this model.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Case 3: entry specifies only the provider; it should apply to every
        // model from that provider.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Case 4: same model id but a different provider ("anthropic") — the
        // entry must NOT match the fake provider, so no temperature is set.
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }
3566
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // A brand-new thread's summary starts out Pending and falls back to
        // the default display text.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should be ignored while Pending.
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message so a summarization request gets triggered.
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Summary generation should start once there are >= 2 messages.
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Manual sets are also ignored while Generating.
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks, then close the stream.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // The streamed chunks should be concatenated into the Ready summary.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Once Ready, manual overrides are accepted.
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Setting an empty summary keeps the Ready state but displays the
        // default text.
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3651
3652 #[gpui::test]
3653 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3654 init_test_settings(cx);
3655
3656 let project = create_test_project(cx, json!({})).await;
3657
3658 let (_, _thread_store, thread, _context_store, model) =
3659 setup_test_environment(cx, project.clone()).await;
3660
3661 test_summarize_error(&model, &thread, cx);
3662
3663 // Now we should be able to set a summary
3664 thread.update(cx, |thread, cx| {
3665 thread.set_summary("Brief Intro", cx);
3666 });
3667
3668 thread.read_with(cx, |thread, _| {
3669 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3670 assert_eq!(thread.summary().or_default(), "Brief Intro");
3671 });
3672 }
3673
3674 #[gpui::test]
3675 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3676 init_test_settings(cx);
3677
3678 let project = create_test_project(cx, json!({})).await;
3679
3680 let (_, _thread_store, thread, _context_store, model) =
3681 setup_test_environment(cx, project.clone()).await;
3682
3683 test_summarize_error(&model, &thread, cx);
3684
3685 // Sending another message should not trigger another summarize request
3686 thread.update(cx, |thread, cx| {
3687 thread.insert_user_message(
3688 "How are you?",
3689 ContextLoadResult::default(),
3690 None,
3691 vec![],
3692 cx,
3693 );
3694 thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
3695 });
3696
3697 let fake_model = model.as_fake();
3698 simulate_successful_response(&fake_model, cx);
3699
3700 thread.read_with(cx, |thread, _| {
3701 // State is still Error, not Generating
3702 assert!(matches!(thread.summary(), ThreadSummary::Error));
3703 });
3704
3705 // But the summarize request can be invoked manually
3706 thread.update(cx, |thread, cx| {
3707 thread.summarize(cx);
3708 });
3709
3710 thread.read_with(cx, |thread, _| {
3711 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3712 });
3713
3714 cx.run_until_parked();
3715 fake_model.stream_last_completion_response("A successful summary");
3716 fake_model.end_last_completion_stream();
3717 cx.run_until_parked();
3718
3719 thread.read_with(cx, |thread, _| {
3720 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3721 assert_eq!(thread.summary().or_default(), "A successful summary");
3722 });
3723 }
3724
    #[gpui::test]
    fn test_resolve_tool_name_conflicts() {
        use assistant_tool::{Tool, ToolSource};

        // No conflicts: all names pass through unchanged.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
            ],
            vec!["tool1", "tool2", "tool3"],
        );

        // Two context servers expose the same tool name: both get prefixed
        // with their server id.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // A native tool also shares the name: the native tool keeps it and
        // only the context-server tools get prefixed.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::Native),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
        );

        // Test that a tool with a very long name is always truncated (to 64
        // characters, per the expected value below).
        assert_resolve_tool_name_conflicts(
            vec![TestTool::new(
                "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
                ToolSource::Native,
            )],
            vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
        );

        // Test deduplication of tools with very long names; in this case the
        // mcp server prefix should be truncated so the combined name fits.
        assert_resolve_tool_name_conflicts(
            vec![
                TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
                TestTool::new(
                    "tool-with-very-very-very-long-name",
                    ToolSource::ContextServer {
                        id: "mcp-with-very-very-very-long-name".into(),
                    },
                ),
            ],
            vec![
                "tool-with-very-very-very-long-name",
                "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
            ],
        );

        // Runs resolve_tool_name_conflicts on `tools` and asserts the
        // resulting names equal `expected`, in order.
        fn assert_resolve_tool_name_conflicts(
            tools: Vec<TestTool>,
            expected: Vec<impl Into<String>>,
        ) {
            let tools: Vec<Arc<dyn Tool>> = tools
                .into_iter()
                .map(|t| Arc::new(t) as Arc<dyn Tool>)
                .collect();
            let tools = resolve_tool_name_conflicts(&tools);
            assert_eq!(tools.len(), expected.len());
            for (i, expected_name) in expected.into_iter().enumerate() {
                let expected_name = expected_name.into();
                let actual_name = &tools[i].0;
                assert_eq!(
                    actual_name, &expected_name,
                    "Expected '{}' got '{}' at index {}",
                    expected_name, actual_name, i
                );
            }
        }

        // Minimal Tool implementation carrying just a name and a source —
        // the only fields conflict resolution looks at.
        struct TestTool {
            name: String,
            source: ToolSource,
        }

        impl TestTool {
            fn new(name: impl Into<String>, source: ToolSource) -> Self {
                Self {
                    name: name.into(),
                    source,
                }
            }
        }

        impl Tool for TestTool {
            fn name(&self) -> String {
                self.name.clone()
            }

            fn icon(&self) -> IconName {
                IconName::Ai
            }

            fn may_perform_edits(&self) -> bool {
                false
            }

            fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
                true
            }

            fn source(&self) -> ToolSource {
                self.source.clone()
            }

            fn description(&self) -> String {
                "Test tool".to_string()
            }

            fn ui_text(&self, _input: &serde_json::Value) -> String {
                "Test tool".to_string()
            }

            // Never meant to run in these tests; always yields an error task.
            fn run(
                self: Arc<Self>,
                _input: serde_json::Value,
                _request: Arc<LanguageModelRequest>,
                _project: Entity<Project>,
                _action_log: Entity<ActionLog>,
                _model: Arc<dyn LanguageModel>,
                _window: Option<AnyWindowHandle>,
                _cx: &mut App,
            ) -> assistant_tool::ToolResult {
                assistant_tool::ToolResult {
                    output: Task::ready(Err(anyhow::anyhow!("No content"))),
                    card: None,
                }
            }
        }
    }
3866
    /// Shared helper: sends one message, lets the assistant reply, then ends
    /// the summarization stream without emitting any text, which drives the
    /// thread summary into the `Error` state.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // The assistant reply should have kicked off summary generation.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // End the summary stream without streaming any text.
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // The empty response puts the summary into Error, displaying the
        // default text.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3901
    /// Streams a canned assistant reply through the fake model and parks the
    /// executor before and after so the thread fully processes the turn.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3908
    /// Registers the global settings and subsystems the thread tests rely on.
    // NOTE(review): the settings store must be set before the register calls;
    // the rest loosely follows dependency order — confirm before reordering.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3925
3926 // Helper to create a test project with test files
3927 async fn create_test_project(
3928 cx: &mut TestAppContext,
3929 files: serde_json::Value,
3930 ) -> Entity<Project> {
3931 let fs = FakeFs::new(cx.executor());
3932 fs.insert_tree(path!("/test"), files).await;
3933 Project::test(fs, [path!("/test").as_ref()], cx).await
3934 }
3935
    /// Builds the shared test fixture: a windowed `Workspace`, a
    /// `ThreadStore` with an empty tool working set, a fresh `Thread`, a
    /// `ContextStore`, and a fake language model registered as both the
    /// default and the thread-summary model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        // Wrap the fake provider's test model as a trait object so callers
        // see the same type they would for a real model.
        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        // Register the fake model for both completion and summarization so
        // tests can drive every request through `model.as_fake()`.
        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }
3990
3991 async fn add_file_to_context(
3992 project: &Entity<Project>,
3993 context_store: &Entity<ContextStore>,
3994 path: &str,
3995 cx: &mut TestAppContext,
3996 ) -> Result<Entity<language::Buffer>> {
3997 let buffer_path = project
3998 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3999 .unwrap();
4000
4001 let buffer = project
4002 .update(cx, |project, cx| {
4003 project.open_buffer(buffer_path.clone(), cx)
4004 })
4005 .await
4006 .unwrap();
4007
4008 context_store.update(cx, |context_store, cx| {
4009 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
4010 });
4011
4012 Ok(buffer)
4013 }
4014}