1use std::io::Write;
2use std::ops::Range;
3use std::sync::Arc;
4use std::time::Instant;
5
6use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
7use anyhow::{Result, anyhow};
8use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
9use chrono::{DateTime, Utc};
10use collections::{HashMap, HashSet};
11use editor::display_map::CreaseMetadata;
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
18 WeakEntity,
19};
20use language_model::{
21 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
22 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
23 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
24 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
25 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
26 StopReason, TokenUsage,
27};
28use postage::stream::Stream as _;
29use project::Project;
30use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
31use prompt_store::{ModelContext, PromptBuilder};
32use proto::Plan;
33use schemars::JsonSchema;
34use serde::{Deserialize, Serialize};
35use settings::Settings;
36use thiserror::Error;
37use ui::Window;
38use util::{ResultExt as _, post_inc};
39
40use uuid::Uuid;
41use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
42
43use crate::ThreadStore;
44use crate::agent_profile::AgentProfile;
45use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
46use crate::thread_store::{
47 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
48 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
49};
50use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
51
52#[derive(
53 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
54)]
55pub struct ThreadId(Arc<str>);
56
57impl ThreadId {
58 pub fn new() -> Self {
59 Self(Uuid::new_v4().to_string().into())
60 }
61}
62
63impl std::fmt::Display for ThreadId {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 write!(f, "{}", self.0)
66 }
67}
68
69impl From<&str> for ThreadId {
70 fn from(value: &str) -> Self {
71 Self(value.into())
72 }
73}
74
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Generates a fresh random prompt id (a UUID string).
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
92
/// Monotonically increasing, thread-local identifier for a [`Message`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
101
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message text.
    pub range: Range<usize>,
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text/thinking segments making up the message body.
    pub segments: Vec<MessageSegment>,
    /// Context that was attached when this message was sent.
    pub loaded_context: LoadedContext,
    pub creases: Vec<MessageCrease>,
    // Hidden messages are still sent to the model (e.g. the automatic
    // "continue" prompt) — presumably not shown in the UI; confirm at call sites.
    pub is_hidden: bool,
}
121
122impl Message {
123 /// Returns whether the message contains any meaningful text that should be displayed
124 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
125 pub fn should_display_content(&self) -> bool {
126 self.segments.iter().all(|segment| segment.should_display())
127 }
128
129 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
130 if let Some(MessageSegment::Thinking {
131 text: segment,
132 signature: current_signature,
133 }) = self.segments.last_mut()
134 {
135 if let Some(signature) = signature {
136 *current_signature = Some(signature);
137 }
138 segment.push_str(text);
139 } else {
140 self.segments.push(MessageSegment::Thinking {
141 text: text.to_string(),
142 signature,
143 });
144 }
145 }
146
147 pub fn push_redacted_thinking(&mut self, data: String) {
148 self.segments.push(MessageSegment::RedactedThinking(data));
149 }
150
151 pub fn push_text(&mut self, text: &str) {
152 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
153 segment.push_str(text);
154 } else {
155 self.segments.push(MessageSegment::Text(text.to_string()));
156 }
157 }
158
159 pub fn to_string(&self) -> String {
160 let mut result = String::new();
161
162 if !self.loaded_context.text.is_empty() {
163 result.push_str(&self.loaded_context.text);
164 }
165
166 for segment in &self.segments {
167 match segment {
168 MessageSegment::Text(text) => result.push_str(text),
169 MessageSegment::Thinking { text, .. } => {
170 result.push_str("<think>\n");
171 result.push_str(text);
172 result.push_str("\n</think>");
173 }
174 MessageSegment::RedactedThinking(_) => {}
175 }
176 }
177
178 result
179 }
180}
181
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain response text.
    Text(String),
    /// Chain-of-thought text, optionally accompanied by a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque provider-encrypted thinking data; never displayable.
    RedactedThinking(String),
}

impl MessageSegment {
    /// Returns whether this segment has content worth rendering.
    ///
    /// Bug fix: this previously returned `text.is_empty()`, inverting the
    /// intended meaning — empty segments claimed to be displayable while
    /// non-empty ones did not, contradicting
    /// [`Message::should_display_content`]'s contract.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
201
/// Point-in-time description of the project's worktrees, captured when a
/// thread is created.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers with unsaved changes at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including git metadata when available.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Git metadata recorded for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
222
/// Associates a git checkpoint with the message it was taken for, so the
/// working tree can later be rolled back to that point.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User rating applied to a thread or an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
234
235pub enum LastRestoreCheckpoint {
236 Pending {
237 message_id: MessageId,
238 },
239 Error {
240 message_id: MessageId,
241 error: String,
242 },
243}
244
245impl LastRestoreCheckpoint {
246 pub fn message_id(&self) -> MessageId {
247 match self {
248 LastRestoreCheckpoint::Pending { message_id } => *message_id,
249 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
250 }
251 }
252}
253
254#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
255pub enum DetailedSummaryState {
256 #[default]
257 NotGenerated,
258 Generating {
259 message_id: MessageId,
260 },
261 Generated {
262 text: SharedString,
263 message_id: MessageId,
264 },
265}
266
267impl DetailedSummaryState {
268 fn text(&self) -> Option<SharedString> {
269 if let Self::Generated { text, .. } = self {
270 Some(text.clone())
271 } else {
272 None
273 }
274 }
275}
276
/// Running token totals for a thread, measured against the model's
/// context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: u64,
    /// Context-window size; 0 when no model is selected.
    pub max: u64,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or over it.
    ///
    /// In debug builds the warning threshold can be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable. Bug fix: a
    /// malformed value previously panicked via `.parse().unwrap()`; it now
    /// falls back to the default threshold.
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more consumed.
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close the thread's token usage is to the context-window limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
321
/// Where a pending completion currently sits in the provider's queue.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
328
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Kept in ascending `MessageId` order (ids are assigned via
    // `next_message_id.post_inc()` on append), which `Thread::message`
    // relies on for binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints keyed by the message they were taken for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint taken when the latest user message was inserted;
    // finalized (kept or discarded) once generation settles.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Timestamp of the most recent streamed chunk; drives
    // `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
370
371#[derive(Clone, Debug, PartialEq, Eq)]
372pub enum ThreadSummary {
373 Pending,
374 Generating,
375 Ready(SharedString),
376 Error,
377}
378
379impl ThreadSummary {
380 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
381
382 pub fn or_default(&self) -> SharedString {
383 self.unwrap_or(Self::DEFAULT)
384 }
385
386 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
387 self.ready().unwrap_or_else(|| message.into())
388 }
389
390 pub fn ready(&self) -> Option<SharedString> {
391 match self {
392 ThreadSummary::Ready(summary) => Some(summary.clone()),
393 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
394 }
395 }
396}
397
/// Recorded when a request overflows the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
405
406impl Thread {
    /// Creates an empty thread with a fresh [`ThreadId`].
    ///
    /// The configured model and agent profile are seeded from the global
    /// language-model registry and agent settings; an initial project
    /// snapshot is captured on a background task.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the snapshot asynchronously so construction does
                // not have to await it; `shared()` lets serialization read
                // it later.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
462
    /// Reconstructs a thread from its serialized form.
    ///
    /// `window` is `None` in headless mode. Ephemeral state (pending
    /// completions, checkpoints, feedback, request callbacks) is reset
    /// rather than restored.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Message ids ascend, so the next id follows the last serialized one.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the serialized model selection; fall back to the global
        // default when it is absent or no longer available.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text survives
                    // serialization; handles and images are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
587
    /// Installs a callback that observes each outgoing completion request
    /// together with the events received in response.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
595
    /// The stable identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
599
    /// The agent profile currently active for this thread.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }
603
604 pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
605 if &id != self.profile.id() {
606 self.profile = AgentProfile::new(id, self.tools.clone());
607 cx.emit(ThreadEvent::ProfileChanged);
608 }
609 }
610
    /// Returns whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
614
    /// Time of the last recorded modification to this thread.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
618
    /// Stamps the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
622
    /// Replaces the last prompt id with a fresh one.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }
626
    /// Returns a clone of the shared project context (passed in as the
    /// system prompt source at construction).
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
630
631 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
632 if self.configured_model.is_none() {
633 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
634 }
635 self.configured_model.clone()
636 }
637
    /// The currently configured model, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }
641
    /// Sets (or clears) the configured model and notifies observers.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
646
    /// The current summary state of the thread.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
650
    /// Replaces the summary text, substituting the default for an empty
    /// string. No-op while a summary is pending or generating, and when
    /// the text is unchanged.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            // An errored summary may always be overwritten; comparing
            // against the default keeps setting the default text a no-op.
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }
669
    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }
673
    /// Sets the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
677
678 pub fn message(&self, id: MessageId) -> Option<&Message> {
679 let index = self
680 .messages
681 .binary_search_by(|message| message.id.cmp(&id))
682 .ok()?;
683
684 self.messages.get(index)
685 }
686
    /// Iterates over all messages in ascending id order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
690
    /// Returns whether a completion is in flight or any tool use has yet
    /// to finish.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }
694
    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    // NOTE(review): the `None` claim relies on `last_received_chunk_at`
    // being cleared when generation ends — that reset is not visible here;
    // confirm elsewhere in the file.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk after which the stream counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }
703
    /// Records that a streamed completion chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }
707
    /// Queue state of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
713
    /// The tool working set this thread draws its tools from.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
717
718 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
719 self.tool_use
720 .pending_tool_uses()
721 .into_iter()
722 .find(|tool_use| &tool_use.id == id)
723 }
724
    /// Pending tool uses waiting on user confirmation before they run.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }
731
    /// Returns whether any tool uses are still pending (including errored ones).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
735
    /// The git checkpoint recorded for the given message, if one was taken.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
739
    /// Restores the project's working tree to `checkpoint` and, on
    /// success, truncates the thread back to the checkpoint's message.
    /// Progress and failure are exposed via
    /// [`Self::last_restore_checkpoint`].
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // The working tree was rolled back; drop the messages
                    // that came after the checkpoint.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
775
776 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
777 let pending_checkpoint = if self.is_generating() {
778 return;
779 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
780 checkpoint
781 } else {
782 return;
783 };
784
785 self.finalize_checkpoint(pending_checkpoint, cx);
786 }
787
    /// Takes a fresh checkpoint of the working tree and keeps
    /// `pending_checkpoint` only when the tree changed since it was taken
    /// — or when the comparison itself could not be made.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison counts as "not equal" so the
                // checkpoint is conservatively kept.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Checkpointing the current state failed; keep the pending
            // checkpoint rather than silently losing it.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
822
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
829
    /// The state of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
833
    /// Removes the message with `message_id` and every later message,
    /// discarding the removed messages' checkpoints. No-op when the id is
    /// not present.
    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
847
    /// The contexts attached to the given message; empty when the message
    /// does not exist or carries none.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
        self.messages
            .iter()
            .find(|message| message.id == id)
            .into_iter()
            .flat_map(|message| message.loaded_context.contexts.iter())
    }
855
856 pub fn is_turn_end(&self, ix: usize) -> bool {
857 if self.messages.is_empty() {
858 return false;
859 }
860
861 if !self.is_generating() && ix == self.messages.len() - 1 {
862 return true;
863 }
864
865 let Some(message) = self.messages.get(ix) else {
866 return false;
867 };
868
869 if message.role != Role::Assistant {
870 return false;
871 }
872
873 self.messages
874 .get(ix + 1)
875 .and_then(|message| {
876 self.message(message.id)
877 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
878 })
879 .unwrap_or(false)
880 }
881
    /// Usage reported by the provider for the most recent request, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }
885
    /// Returns whether the provider's tool-use limit has been hit.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
889
    /// Returns whether all of the tool uses have finished running.
    /// Vacuously true when nothing is pending.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }
899
900 /// Returns whether any pending tool uses may perform edits
901 pub fn has_pending_edit_tool_uses(&self) -> bool {
902 self.tool_use
903 .pending_tool_uses()
904 .iter()
905 .filter(|pending_tool_use| !pending_tool_use.status.is_error())
906 .any(|pending_tool_use| pending_tool_use.may_perform_edits)
907 }
908
    /// Tool uses issued by the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
912
    /// Tool results produced for the given assistant message's tool uses.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }
919
    /// The result of the tool use with the given id, if it has completed.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
923
924 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
925 match &self.tool_use.tool_result(id)?.content {
926 LanguageModelToolResultContent::Text(text) => Some(text),
927 LanguageModelToolResultContent::Image(_) => {
928 // TODO: We should display image
929 None
930 }
931 }
932 }
933
    /// The UI card associated with the tool result, if one was registered.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
937
938 /// Return tools that are both enabled and supported by the model
939 pub fn available_tools(
940 &self,
941 cx: &App,
942 model: Arc<dyn LanguageModel>,
943 ) -> Vec<LanguageModelRequestTool> {
944 if model.supports_tools() {
945 resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
946 .into_iter()
947 .filter_map(|(name, tool)| {
948 // Skip tools that cannot be supported
949 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
950 Some(LanguageModelRequestTool {
951 name,
952 description: tool.description(),
953 input_schema,
954 })
955 })
956 .collect()
957 } else {
958 Vec::default()
959 }
960 }
961
    /// Appends a user message, marking any buffers referenced by the
    /// loaded context as read in the action log and, when a git checkpoint
    /// is supplied, recording it as the pending checkpoint for this
    /// message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
998
    /// Appends a hidden user message asking the model to continue where it
    /// left off, clearing any pending checkpoint.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        self.pending_checkpoint = None;

        id
    }
1012
    /// Appends an assistant message with the given segments and no
    /// attached context.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1027
    /// Core insertion: assigns the next message id, appends the message,
    /// bumps `updated_at`, and emits [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1050
    /// Rewrites an existing message in place. Context is only replaced
    /// when `loaded_context` is `Some`; a supplied checkpoint is recorded
    /// for the message. Returns `false` when the id is unknown.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        creases: Vec<MessageCrease>,
        loaded_context: Option<LoadedContext>,
        checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        message.creases = creases;
        if let Some(context) = loaded_context {
            message.loaded_context = context;
        }
        if let Some(git_checkpoint) = checkpoint {
            self.checkpoints_by_message.insert(
                id,
                ThreadCheckpoint {
                    message_id: id,
                    git_checkpoint,
                },
            );
        }
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
1083
1084 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1085 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1086 return false;
1087 };
1088 self.messages.remove(index);
1089 self.touch_updated_at();
1090 cx.emit(ThreadEvent::MessageDeleted(id));
1091 true
1092 }
1093
1094 /// Returns the representation of this [`Thread`] in a textual form.
1095 ///
1096 /// This is the representation we use when attaching a thread as context to another thread.
1097 pub fn text(&self) -> String {
1098 let mut text = String::new();
1099
1100 for message in &self.messages {
1101 text.push_str(match message.role {
1102 language_model::Role::User => "User:",
1103 language_model::Role::Assistant => "Agent:",
1104 language_model::Role::System => "System:",
1105 });
1106 text.push('\n');
1107
1108 for segment in &message.segments {
1109 match segment {
1110 MessageSegment::Text(content) => text.push_str(content),
1111 MessageSegment::Thinking { text: content, .. } => {
1112 text.push_str(&format!("<think>{}</think>", content))
1113 }
1114 MessageSegment::RedactedThinking(_) => {}
1115 }
1116 }
1117 text.push('\n');
1118 }
1119
1120 text
1121 }
1122
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Returns a task because the initial project snapshot (captured when the
    /// thread started) must resolve before the thread state is read.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Await the shared snapshot future first, then read the rest of the
            // thread state synchronously.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                // Each message is flattened together with the tool uses, tool
                // results, loaded context, and editor creases attached to it.
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                // Snapshot of the current detailed-summary watch state.
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1208
    /// Returns how many automatic agentic turns this thread may still take.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1212
    /// Sets the number of automatic agentic turns this thread may still take.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1216
1217 pub fn send_to_model(
1218 &mut self,
1219 model: Arc<dyn LanguageModel>,
1220 intent: CompletionIntent,
1221 window: Option<AnyWindowHandle>,
1222 cx: &mut Context<Self>,
1223 ) {
1224 if self.remaining_turns == 0 {
1225 return;
1226 }
1227
1228 self.remaining_turns -= 1;
1229
1230 let request = self.to_completion_request(model.clone(), intent, cx);
1231
1232 self.stream_completion(request, model, window, cx);
1233 }
1234
1235 pub fn used_tools_since_last_user_message(&self) -> bool {
1236 for message in self.messages.iter().rev() {
1237 if self.tool_use.message_has_tool_results(message.id) {
1238 return true;
1239 } else if message.role == Role::User {
1240 return false;
1241 }
1242 }
1243
1244 false
1245 }
1246
    /// Builds the [`LanguageModelRequest`] for this thread: system prompt,
    /// message history (including tool uses/results), available tools, and the
    /// prompt-cache anchor.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        // The system prompt lists the tools the model may call, so resolve
        // them first.
        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from the project context; failures are
        // surfaced to the user but the request is still built without it.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message that is safe to mark as a prompt
        // cache anchor (i.e. no pending tool uses before it).
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) precedes the message text.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses are attached to this message, while their
            // results go into a synthetic follow-up User message (the shape
            // the completion APIs expect). Pending tool uses are skipped and
            // disqualify this message from being the cache anchor.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // Mark only the latest eligible message as the cache breakpoint.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        // Fall back to Normal mode for models without max-mode support.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1404
1405 fn to_summarize_request(
1406 &self,
1407 model: &Arc<dyn LanguageModel>,
1408 intent: CompletionIntent,
1409 added_user_message: String,
1410 cx: &App,
1411 ) -> LanguageModelRequest {
1412 let mut request = LanguageModelRequest {
1413 thread_id: None,
1414 prompt_id: None,
1415 intent: Some(intent),
1416 mode: None,
1417 messages: vec![],
1418 tools: Vec::new(),
1419 tool_choice: None,
1420 stop: Vec::new(),
1421 temperature: AgentSettings::temperature_for_model(model, cx),
1422 };
1423
1424 for message in &self.messages {
1425 let mut request_message = LanguageModelRequestMessage {
1426 role: message.role,
1427 content: Vec::new(),
1428 cache: false,
1429 };
1430
1431 for segment in &message.segments {
1432 match segment {
1433 MessageSegment::Text(text) => request_message
1434 .content
1435 .push(MessageContent::Text(text.clone())),
1436 MessageSegment::Thinking { .. } => {}
1437 MessageSegment::RedactedThinking(_) => {}
1438 }
1439 }
1440
1441 if request_message.content.is_empty() {
1442 continue;
1443 }
1444
1445 request.messages.push(request_message);
1446 }
1447
1448 request.messages.push(LanguageModelRequestMessage {
1449 role: Role::User,
1450 content: vec![MessageContent::Text(added_user_message)],
1451 cache: false,
1452 });
1453
1454 request
1455 }
1456
1457 pub fn stream_completion(
1458 &mut self,
1459 request: LanguageModelRequest,
1460 model: Arc<dyn LanguageModel>,
1461 window: Option<AnyWindowHandle>,
1462 cx: &mut Context<Self>,
1463 ) {
1464 self.tool_use_limit_reached = false;
1465
1466 let pending_completion_id = post_inc(&mut self.completion_count);
1467 let mut request_callback_parameters = if self.request_callback.is_some() {
1468 Some((request.clone(), Vec::new()))
1469 } else {
1470 None
1471 };
1472 let prompt_id = self.last_prompt_id.clone();
1473 let tool_use_metadata = ToolUseMetadata {
1474 model: model.clone(),
1475 thread_id: self.id.clone(),
1476 prompt_id: prompt_id.clone(),
1477 };
1478
1479 self.last_received_chunk_at = Some(Instant::now());
1480
1481 let task = cx.spawn(async move |thread, cx| {
1482 let stream_completion_future = model.stream_completion(request, &cx);
1483 let initial_token_usage =
1484 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1485 let stream_completion = async {
1486 let mut events = stream_completion_future.await?;
1487
1488 let mut stop_reason = StopReason::EndTurn;
1489 let mut current_token_usage = TokenUsage::default();
1490
1491 thread
1492 .update(cx, |_thread, cx| {
1493 cx.emit(ThreadEvent::NewRequest);
1494 })
1495 .ok();
1496
1497 let mut request_assistant_message_id = None;
1498
1499 while let Some(event) = events.next().await {
1500 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1501 response_events
1502 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1503 }
1504
1505 thread.update(cx, |thread, cx| {
1506 let event = match event {
1507 Ok(event) => event,
1508 Err(LanguageModelCompletionError::BadInputJson {
1509 id,
1510 tool_name,
1511 raw_input: invalid_input_json,
1512 json_parse_error,
1513 }) => {
1514 thread.receive_invalid_tool_json(
1515 id,
1516 tool_name,
1517 invalid_input_json,
1518 json_parse_error,
1519 window,
1520 cx,
1521 );
1522 return Ok(());
1523 }
1524 Err(LanguageModelCompletionError::Other(error)) => {
1525 return Err(error);
1526 }
1527 Err(err @ LanguageModelCompletionError::RateLimit(..)) => {
1528 return Err(err.into());
1529 }
1530 };
1531
1532 match event {
1533 LanguageModelCompletionEvent::StartMessage { .. } => {
1534 request_assistant_message_id =
1535 Some(thread.insert_assistant_message(
1536 vec![MessageSegment::Text(String::new())],
1537 cx,
1538 ));
1539 }
1540 LanguageModelCompletionEvent::Stop(reason) => {
1541 stop_reason = reason;
1542 }
1543 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1544 thread.update_token_usage_at_last_message(token_usage);
1545 thread.cumulative_token_usage = thread.cumulative_token_usage
1546 + token_usage
1547 - current_token_usage;
1548 current_token_usage = token_usage;
1549 }
1550 LanguageModelCompletionEvent::Text(chunk) => {
1551 thread.received_chunk();
1552
1553 cx.emit(ThreadEvent::ReceivedTextChunk);
1554 if let Some(last_message) = thread.messages.last_mut() {
1555 if last_message.role == Role::Assistant
1556 && !thread.tool_use.has_tool_results(last_message.id)
1557 {
1558 last_message.push_text(&chunk);
1559 cx.emit(ThreadEvent::StreamedAssistantText(
1560 last_message.id,
1561 chunk,
1562 ));
1563 } else {
1564 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1565 // of a new Assistant response.
1566 //
1567 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1568 // will result in duplicating the text of the chunk in the rendered Markdown.
1569 request_assistant_message_id =
1570 Some(thread.insert_assistant_message(
1571 vec![MessageSegment::Text(chunk.to_string())],
1572 cx,
1573 ));
1574 };
1575 }
1576 }
1577 LanguageModelCompletionEvent::Thinking {
1578 text: chunk,
1579 signature,
1580 } => {
1581 thread.received_chunk();
1582
1583 if let Some(last_message) = thread.messages.last_mut() {
1584 if last_message.role == Role::Assistant
1585 && !thread.tool_use.has_tool_results(last_message.id)
1586 {
1587 last_message.push_thinking(&chunk, signature);
1588 cx.emit(ThreadEvent::StreamedAssistantThinking(
1589 last_message.id,
1590 chunk,
1591 ));
1592 } else {
1593 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1594 // of a new Assistant response.
1595 //
1596 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1597 // will result in duplicating the text of the chunk in the rendered Markdown.
1598 request_assistant_message_id =
1599 Some(thread.insert_assistant_message(
1600 vec![MessageSegment::Thinking {
1601 text: chunk.to_string(),
1602 signature,
1603 }],
1604 cx,
1605 ));
1606 };
1607 }
1608 }
1609 LanguageModelCompletionEvent::RedactedThinking {
1610 data
1611 } => {
1612 thread.received_chunk();
1613
1614 if let Some(last_message) = thread.messages.last_mut() {
1615 if last_message.role == Role::Assistant
1616 && !thread.tool_use.has_tool_results(last_message.id)
1617 {
1618 last_message.push_redacted_thinking(data);
1619 } else {
1620 request_assistant_message_id =
1621 Some(thread.insert_assistant_message(
1622 vec![MessageSegment::RedactedThinking(data)],
1623 cx,
1624 ));
1625 };
1626 }
1627 }
1628 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1629 let last_assistant_message_id = request_assistant_message_id
1630 .unwrap_or_else(|| {
1631 let new_assistant_message_id =
1632 thread.insert_assistant_message(vec![], cx);
1633 request_assistant_message_id =
1634 Some(new_assistant_message_id);
1635 new_assistant_message_id
1636 });
1637
1638 let tool_use_id = tool_use.id.clone();
1639 let streamed_input = if tool_use.is_input_complete {
1640 None
1641 } else {
1642 Some((&tool_use.input).clone())
1643 };
1644
1645 let ui_text = thread.tool_use.request_tool_use(
1646 last_assistant_message_id,
1647 tool_use,
1648 tool_use_metadata.clone(),
1649 cx,
1650 );
1651
1652 if let Some(input) = streamed_input {
1653 cx.emit(ThreadEvent::StreamedToolUse {
1654 tool_use_id,
1655 ui_text,
1656 input,
1657 });
1658 }
1659 }
1660 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1661 if let Some(completion) = thread
1662 .pending_completions
1663 .iter_mut()
1664 .find(|completion| completion.id == pending_completion_id)
1665 {
1666 match status_update {
1667 CompletionRequestStatus::Queued {
1668 position,
1669 } => {
1670 completion.queue_state = QueueState::Queued { position };
1671 }
1672 CompletionRequestStatus::Started => {
1673 completion.queue_state = QueueState::Started;
1674 }
1675 CompletionRequestStatus::Failed {
1676 code, message, request_id
1677 } => {
1678 anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
1679 }
1680 CompletionRequestStatus::UsageUpdated {
1681 amount, limit
1682 } => {
1683 let usage = RequestUsage { limit, amount: amount as i32 };
1684
1685 thread.last_usage = Some(usage);
1686 }
1687 CompletionRequestStatus::ToolUseLimitReached => {
1688 thread.tool_use_limit_reached = true;
1689 cx.emit(ThreadEvent::ToolUseLimitReached);
1690 }
1691 }
1692 }
1693 }
1694 }
1695
1696 thread.touch_updated_at();
1697 cx.emit(ThreadEvent::StreamedCompletion);
1698 cx.notify();
1699
1700 thread.auto_capture_telemetry(cx);
1701 Ok(())
1702 })??;
1703
1704 smol::future::yield_now().await;
1705 }
1706
1707 thread.update(cx, |thread, cx| {
1708 thread.last_received_chunk_at = None;
1709 thread
1710 .pending_completions
1711 .retain(|completion| completion.id != pending_completion_id);
1712
1713 // If there is a response without tool use, summarize the message. Otherwise,
1714 // allow two tool uses before summarizing.
1715 if matches!(thread.summary, ThreadSummary::Pending)
1716 && thread.messages.len() >= 2
1717 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1718 {
1719 thread.summarize(cx);
1720 }
1721 })?;
1722
1723 anyhow::Ok(stop_reason)
1724 };
1725
1726 let result = stream_completion.await;
1727
1728 thread
1729 .update(cx, |thread, cx| {
1730 thread.finalize_pending_checkpoint(cx);
1731 match result.as_ref() {
1732 Ok(stop_reason) => match stop_reason {
1733 StopReason::ToolUse => {
1734 let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1735 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1736 }
1737 StopReason::EndTurn | StopReason::MaxTokens => {
1738 thread.project.update(cx, |project, cx| {
1739 project.set_agent_location(None, cx);
1740 });
1741 }
1742 StopReason::Refusal => {
1743 thread.project.update(cx, |project, cx| {
1744 project.set_agent_location(None, cx);
1745 });
1746
1747 // Remove the turn that was refused.
1748 //
1749 // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1750 {
1751 let mut messages_to_remove = Vec::new();
1752
1753 for (ix, message) in thread.messages.iter().enumerate().rev() {
1754 messages_to_remove.push(message.id);
1755
1756 if message.role == Role::User {
1757 if ix == 0 {
1758 break;
1759 }
1760
1761 if let Some(prev_message) = thread.messages.get(ix - 1) {
1762 if prev_message.role == Role::Assistant {
1763 break;
1764 }
1765 }
1766 }
1767 }
1768
1769 for message_id in messages_to_remove {
1770 thread.delete_message(message_id, cx);
1771 }
1772 }
1773
1774 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1775 header: "Language model refusal".into(),
1776 message: "Model refused to generate content for safety reasons.".into(),
1777 }));
1778 }
1779 },
1780 Err(error) => {
1781 thread.project.update(cx, |project, cx| {
1782 project.set_agent_location(None, cx);
1783 });
1784
1785 if error.is::<PaymentRequiredError>() {
1786 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1787 } else if let Some(error) =
1788 error.downcast_ref::<ModelRequestLimitReachedError>()
1789 {
1790 cx.emit(ThreadEvent::ShowError(
1791 ThreadError::ModelRequestLimitReached { plan: error.plan },
1792 ));
1793 } else if let Some(known_error) =
1794 error.downcast_ref::<LanguageModelKnownError>()
1795 {
1796 match known_error {
1797 LanguageModelKnownError::ContextWindowLimitExceeded {
1798 tokens,
1799 } => {
1800 thread.exceeded_window_error = Some(ExceededWindowError {
1801 model_id: model.id(),
1802 token_count: *tokens,
1803 });
1804 cx.notify();
1805 }
1806 }
1807 } else {
1808 let error_message = error
1809 .chain()
1810 .map(|err| err.to_string())
1811 .collect::<Vec<_>>()
1812 .join("\n");
1813 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1814 header: "Error interacting with language model".into(),
1815 message: SharedString::from(error_message.clone()),
1816 }));
1817 }
1818
1819 thread.cancel_last_completion(window, cx);
1820 }
1821 }
1822
1823 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1824
1825 if let Some((request_callback, (request, response_events))) = thread
1826 .request_callback
1827 .as_mut()
1828 .zip(request_callback_parameters.as_ref())
1829 {
1830 request_callback(request, response_events);
1831 }
1832
1833 thread.auto_capture_telemetry(cx);
1834
1835 if let Ok(initial_usage) = initial_token_usage {
1836 let usage = thread.cumulative_token_usage - initial_usage;
1837
1838 telemetry::event!(
1839 "Assistant Thread Completion",
1840 thread_id = thread.id().to_string(),
1841 prompt_id = prompt_id,
1842 model = model.telemetry_id(),
1843 model_provider = model.provider_id().to_string(),
1844 input_tokens = usage.input_tokens,
1845 output_tokens = usage.output_tokens,
1846 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1847 cache_read_input_tokens = usage.cache_read_input_tokens,
1848 );
1849 }
1850 })
1851 .ok();
1852 });
1853
1854 self.pending_completions.push(PendingCompletion {
1855 id: pending_completion_id,
1856 queue_state: QueueState::Sending,
1857 _task: task,
1858 });
1859 }
1860
1861 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1862 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1863 println!("No thread summary model");
1864 return;
1865 };
1866
1867 if !model.provider.is_authenticated(cx) {
1868 return;
1869 }
1870
1871 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
1872
1873 let request = self.to_summarize_request(
1874 &model.model,
1875 CompletionIntent::ThreadSummarization,
1876 added_user_message.into(),
1877 cx,
1878 );
1879
1880 self.summary = ThreadSummary::Generating;
1881
1882 self.pending_summary = cx.spawn(async move |this, cx| {
1883 let result = async {
1884 let mut messages = model.model.stream_completion(request, &cx).await?;
1885
1886 let mut new_summary = String::new();
1887 while let Some(event) = messages.next().await {
1888 let Ok(event) = event else {
1889 continue;
1890 };
1891 let text = match event {
1892 LanguageModelCompletionEvent::Text(text) => text,
1893 LanguageModelCompletionEvent::StatusUpdate(
1894 CompletionRequestStatus::UsageUpdated { amount, limit },
1895 ) => {
1896 this.update(cx, |thread, _cx| {
1897 thread.last_usage = Some(RequestUsage {
1898 limit,
1899 amount: amount as i32,
1900 });
1901 })?;
1902 continue;
1903 }
1904 _ => continue,
1905 };
1906
1907 let mut lines = text.lines();
1908 new_summary.extend(lines.next());
1909
1910 // Stop if the LLM generated multiple lines.
1911 if lines.next().is_some() {
1912 break;
1913 }
1914 }
1915
1916 anyhow::Ok(new_summary)
1917 }
1918 .await;
1919
1920 this.update(cx, |this, cx| {
1921 match result {
1922 Ok(new_summary) => {
1923 if new_summary.is_empty() {
1924 this.summary = ThreadSummary::Error;
1925 } else {
1926 this.summary = ThreadSummary::Ready(new_summary.into());
1927 }
1928 }
1929 Err(err) => {
1930 this.summary = ThreadSummary::Error;
1931 log::error!("Failed to generate thread summary: {}", err);
1932 }
1933 }
1934 cx.emit(ThreadEvent::SummaryGenerated);
1935 })
1936 .log_err()?;
1937
1938 Some(())
1939 });
1940 }
1941
    /// Starts generating a detailed (multi-paragraph) summary of this thread
    /// unless one is already generated or generating for the latest message.
    ///
    /// Progress is published through the `detailed_summary_tx` watch channel;
    /// on success the thread is saved so the summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Announce the in-progress state before spawning the task so readers
        // of the watch channel see `Generating` immediately.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            let Some(mut messages) = stream.await.log_err() else {
                // Stream failed to start: reset the watch state so a later
                // call can retry.
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            // Accumulate all successfully received chunks; individual chunk
            // errors are logged and skipped.
            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2031
2032 pub async fn wait_for_detailed_summary_or_text(
2033 this: &Entity<Self>,
2034 cx: &mut AsyncApp,
2035 ) -> Option<SharedString> {
2036 let mut detailed_summary_rx = this
2037 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2038 .ok()?;
2039 loop {
2040 match detailed_summary_rx.recv().await? {
2041 DetailedSummaryState::Generating { .. } => {}
2042 DetailedSummaryState::NotGenerated => {
2043 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2044 }
2045 DetailedSummaryState::Generated { text, .. } => return Some(text),
2046 }
2047 }
2048 }
2049
2050 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2051 self.detailed_summary_rx
2052 .borrow()
2053 .text()
2054 .unwrap_or_else(|| self.text().into())
2055 }
2056
2057 pub fn is_generating_detailed_summary(&self) -> bool {
2058 matches!(
2059 &*self.detailed_summary_rx.borrow(),
2060 DetailedSummaryState::Generating { .. }
2061 )
2062 }
2063
2064 pub fn use_pending_tools(
2065 &mut self,
2066 window: Option<AnyWindowHandle>,
2067 cx: &mut Context<Self>,
2068 model: Arc<dyn LanguageModel>,
2069 ) -> Vec<PendingToolUse> {
2070 self.auto_capture_telemetry(cx);
2071 let request =
2072 Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
2073 let pending_tool_uses = self
2074 .tool_use
2075 .pending_tool_uses()
2076 .into_iter()
2077 .filter(|tool_use| tool_use.status.is_idle())
2078 .cloned()
2079 .collect::<Vec<_>>();
2080
2081 for tool_use in pending_tool_uses.iter() {
2082 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
2083 if tool.needs_confirmation(&tool_use.input, cx)
2084 && !AgentSettings::get_global(cx).always_allow_tool_actions
2085 {
2086 self.tool_use.confirm_tool_use(
2087 tool_use.id.clone(),
2088 tool_use.ui_text.clone(),
2089 tool_use.input.clone(),
2090 request.clone(),
2091 tool,
2092 );
2093 cx.emit(ThreadEvent::ToolConfirmationNeeded);
2094 } else {
2095 self.run_tool(
2096 tool_use.id.clone(),
2097 tool_use.ui_text.clone(),
2098 tool_use.input.clone(),
2099 request.clone(),
2100 tool,
2101 model.clone(),
2102 window,
2103 cx,
2104 );
2105 }
2106 } else {
2107 self.handle_hallucinated_tool_use(
2108 tool_use.id.clone(),
2109 tool_use.name.clone(),
2110 window,
2111 cx,
2112 );
2113 }
2114 }
2115
2116 pending_tool_uses
2117 }
2118
2119 pub fn handle_hallucinated_tool_use(
2120 &mut self,
2121 tool_use_id: LanguageModelToolUseId,
2122 hallucinated_tool_name: Arc<str>,
2123 window: Option<AnyWindowHandle>,
2124 cx: &mut Context<Thread>,
2125 ) {
2126 let available_tools = self.profile.enabled_tools(cx);
2127
2128 let tool_list = available_tools
2129 .iter()
2130 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2131 .collect::<Vec<_>>()
2132 .join("\n");
2133
2134 let error_message = format!(
2135 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2136 hallucinated_tool_name, tool_list
2137 );
2138
2139 let pending_tool_use = self.tool_use.insert_tool_output(
2140 tool_use_id.clone(),
2141 hallucinated_tool_name,
2142 Err(anyhow!("Missing tool call: {error_message}")),
2143 self.configured_model.as_ref(),
2144 );
2145
2146 cx.emit(ThreadEvent::MissingToolUse {
2147 tool_use_id: tool_use_id.clone(),
2148 ui_text: error_message.into(),
2149 });
2150
2151 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2152 }
2153
2154 pub fn receive_invalid_tool_json(
2155 &mut self,
2156 tool_use_id: LanguageModelToolUseId,
2157 tool_name: Arc<str>,
2158 invalid_json: Arc<str>,
2159 error: String,
2160 window: Option<AnyWindowHandle>,
2161 cx: &mut Context<Thread>,
2162 ) {
2163 log::error!("The model returned invalid input JSON: {invalid_json}");
2164
2165 let pending_tool_use = self.tool_use.insert_tool_output(
2166 tool_use_id.clone(),
2167 tool_name,
2168 Err(anyhow!("Error parsing input JSON: {error}")),
2169 self.configured_model.as_ref(),
2170 );
2171 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2172 pending_tool_use.ui_text.clone()
2173 } else {
2174 log::error!(
2175 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2176 );
2177 format!("Unknown tool {}", tool_use_id).into()
2178 };
2179
2180 cx.emit(ThreadEvent::InvalidToolInput {
2181 tool_use_id: tool_use_id.clone(),
2182 ui_text,
2183 invalid_input_json: invalid_json,
2184 });
2185
2186 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2187 }
2188
2189 pub fn run_tool(
2190 &mut self,
2191 tool_use_id: LanguageModelToolUseId,
2192 ui_text: impl Into<SharedString>,
2193 input: serde_json::Value,
2194 request: Arc<LanguageModelRequest>,
2195 tool: Arc<dyn Tool>,
2196 model: Arc<dyn LanguageModel>,
2197 window: Option<AnyWindowHandle>,
2198 cx: &mut Context<Thread>,
2199 ) {
2200 let task =
2201 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2202 self.tool_use
2203 .run_pending_tool(tool_use_id, ui_text.into(), task);
2204 }
2205
    /// Starts `tool` running and returns the task that records its output on
    /// the thread once it resolves.
    ///
    /// Any UI card produced by the tool is stored immediately; the output is
    /// inserted (and the turn potentially continued) when the task completes.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Run the tool synchronously up to the point where it hands back a
        // future for its output.
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // Best-effort: if the thread was dropped meanwhile, the output
                // is discarded.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2252
2253 fn tool_finished(
2254 &mut self,
2255 tool_use_id: LanguageModelToolUseId,
2256 pending_tool_use: Option<PendingToolUse>,
2257 canceled: bool,
2258 window: Option<AnyWindowHandle>,
2259 cx: &mut Context<Self>,
2260 ) {
2261 if self.all_tools_finished() {
2262 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2263 if !canceled {
2264 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2265 }
2266 self.auto_capture_telemetry(cx);
2267 }
2268 }
2269
2270 cx.emit(ThreadEvent::ToolFinished {
2271 tool_use_id,
2272 pending_tool_use,
2273 });
2274 }
2275
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // Popping drops the completion's `_task`, which stops the stream
        // (gpui tasks are canceled on drop).
        let mut canceled = self.pending_completions.pop().is_some();

        // Tools still in flight also count as a cancellation; finalize each so
        // listeners observe a `ToolFinished` event for it.
        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);

            // When canceled, we always want to insert the checkpoint.
            // (We skip over finalize_pending_checkpoint, because it
            // would conclude we didn't have anything to insert here.)
            if let Some(checkpoint) = self.pending_checkpoint.take() {
                self.insert_checkpoint(checkpoint, cx);
            }
        } else {
            self.finalize_pending_checkpoint(cx);
        }

        canceled
    }
2312
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It performs no work itself
    /// beyond emitting [`ThreadEvent::CancelEditing`].
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2320
    /// The thread-level rating the user gave (if any); see [`Self::report_feedback`].
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2324
    /// The rating the user gave to a specific message, if one was recorded.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2328
    /// Records the user's rating for a single message and reports it to
    /// telemetry along with a serialized copy of the thread and a project
    /// snapshot. Repeated identical feedback for the same message is a no-op.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Unchanged rating: nothing to report.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything that needs `cx` before moving to the background.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failure degrades to a null payload rather than
            // dropping the event.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2385
    /// Records thread-level feedback. If the thread has an assistant message,
    /// the rating is attached to the most recent one via
    /// [`Self::report_message_feedback`]; otherwise a thread-level event is
    /// reported directly.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message yet: report at thread granularity.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Degrade to a null payload if serialization to JSON fails.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2431
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Kick off one snapshot task per visible worktree; they resolve
        // concurrently below.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Record paths of dirty (unsaved) buffers. Best-effort: errors
            // from a torn-down app context are ignored.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2469
2470 fn worktree_snapshot(
2471 worktree: Entity<project::Worktree>,
2472 git_store: Entity<GitStore>,
2473 cx: &App,
2474 ) -> Task<WorktreeSnapshot> {
2475 cx.spawn(async move |cx| {
2476 // Get worktree path and snapshot
2477 let worktree_info = cx.update(|app_cx| {
2478 let worktree = worktree.read(app_cx);
2479 let path = worktree.abs_path().to_string_lossy().to_string();
2480 let snapshot = worktree.snapshot();
2481 (path, snapshot)
2482 });
2483
2484 let Ok((worktree_path, _snapshot)) = worktree_info else {
2485 return WorktreeSnapshot {
2486 worktree_path: String::new(),
2487 git_state: None,
2488 };
2489 };
2490
2491 let git_state = git_store
2492 .update(cx, |git_store, cx| {
2493 git_store
2494 .repositories()
2495 .values()
2496 .find(|repo| {
2497 repo.read(cx)
2498 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2499 .is_some()
2500 })
2501 .cloned()
2502 })
2503 .ok()
2504 .flatten()
2505 .map(|repo| {
2506 repo.update(cx, |repo, _| {
2507 let current_branch =
2508 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2509 repo.send_job(None, |state, _| async move {
2510 let RepositoryState::Local { backend, .. } = state else {
2511 return GitState {
2512 remote_url: None,
2513 head_sha: None,
2514 current_branch,
2515 diff: None,
2516 };
2517 };
2518
2519 let remote_url = backend.remote_url("origin");
2520 let head_sha = backend.head_sha().await;
2521 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2522
2523 GitState {
2524 remote_url,
2525 head_sha,
2526 current_branch,
2527 diff,
2528 }
2529 })
2530 })
2531 });
2532
2533 let git_state = match git_state {
2534 Some(git_state) => match git_state.ok() {
2535 Some(git_state) => git_state.await.ok(),
2536 None => None,
2537 },
2538 None => None,
2539 };
2540
2541 WorktreeSnapshot {
2542 worktree_path,
2543 git_state,
2544 }
2545 })
2546 }
2547
2548 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2549 let mut markdown = Vec::new();
2550
2551 let summary = self.summary().or_default();
2552 writeln!(markdown, "# {summary}\n")?;
2553
2554 for message in self.messages() {
2555 writeln!(
2556 markdown,
2557 "## {role}\n",
2558 role = match message.role {
2559 Role::User => "User",
2560 Role::Assistant => "Agent",
2561 Role::System => "System",
2562 }
2563 )?;
2564
2565 if !message.loaded_context.text.is_empty() {
2566 writeln!(markdown, "{}", message.loaded_context.text)?;
2567 }
2568
2569 if !message.loaded_context.images.is_empty() {
2570 writeln!(
2571 markdown,
2572 "\n{} images attached as context.\n",
2573 message.loaded_context.images.len()
2574 )?;
2575 }
2576
2577 for segment in &message.segments {
2578 match segment {
2579 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2580 MessageSegment::Thinking { text, .. } => {
2581 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2582 }
2583 MessageSegment::RedactedThinking(_) => {}
2584 }
2585 }
2586
2587 for tool_use in self.tool_uses_for_message(message.id, cx) {
2588 writeln!(
2589 markdown,
2590 "**Use Tool: {} ({})**",
2591 tool_use.name, tool_use.id
2592 )?;
2593 writeln!(markdown, "```json")?;
2594 writeln!(
2595 markdown,
2596 "{}",
2597 serde_json::to_string_pretty(&tool_use.input)?
2598 )?;
2599 writeln!(markdown, "```")?;
2600 }
2601
2602 for tool_result in self.tool_results_for_message(message.id) {
2603 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2604 if tool_result.is_error {
2605 write!(markdown, " (Error)")?;
2606 }
2607
2608 writeln!(markdown, "**\n")?;
2609 match &tool_result.content {
2610 LanguageModelToolResultContent::Text(text) => {
2611 writeln!(markdown, "{text}")?;
2612 }
2613 LanguageModelToolResultContent::Image(image) => {
2614 writeln!(markdown, "", image.source)?;
2615 }
2616 }
2617
2618 if let Some(output) = tool_result.output.as_ref() {
2619 writeln!(
2620 markdown,
2621 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2622 serde_json::to_string_pretty(output)?
2623 )?;
2624 }
2625 }
2626 }
2627
2628 Ok(String::from_utf8_lossy(&markdown).to_string())
2629 }
2630
2631 pub fn keep_edits_in_range(
2632 &mut self,
2633 buffer: Entity<language::Buffer>,
2634 buffer_range: Range<language::Anchor>,
2635 cx: &mut Context<Self>,
2636 ) {
2637 self.action_log.update(cx, |action_log, cx| {
2638 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2639 });
2640 }
2641
2642 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2643 self.action_log
2644 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2645 }
2646
2647 pub fn reject_edits_in_ranges(
2648 &mut self,
2649 buffer: Entity<language::Buffer>,
2650 buffer_ranges: Vec<Range<language::Anchor>>,
2651 cx: &mut Context<Self>,
2652 ) -> Task<Result<()>> {
2653 self.action_log.update(cx, |action_log, cx| {
2654 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2655 })
2656 }
2657
    /// The log tracking edits the agent has made to project buffers.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2661
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2665
2666 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2667 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2668 return;
2669 }
2670
2671 let now = Instant::now();
2672 if let Some(last) = self.last_auto_capture_at {
2673 if now.duration_since(last).as_secs() < 10 {
2674 return;
2675 }
2676 }
2677
2678 self.last_auto_capture_at = Some(now);
2679
2680 let thread_id = self.id().clone();
2681 let github_login = self
2682 .project
2683 .read(cx)
2684 .user_store()
2685 .read(cx)
2686 .current_user()
2687 .map(|user| user.github_login.clone());
2688 let client = self.project.read(cx).client();
2689 let serialize_task = self.serialize(cx);
2690
2691 cx.background_executor()
2692 .spawn(async move {
2693 if let Ok(serialized_thread) = serialize_task.await {
2694 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2695 telemetry::event!(
2696 "Agent Thread Auto-Captured",
2697 thread_id = thread_id.to_string(),
2698 thread_data = thread_data,
2699 auto_capture_reason = "tracked_user",
2700 github_login = github_login
2701 );
2702
2703 client.telemetry().flush_events().await;
2704 }
2705 }
2706 })
2707 .detach();
2708 }
2709
    /// Total token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2713
2714 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2715 let Some(model) = self.configured_model.as_ref() else {
2716 return TotalTokenUsage::default();
2717 };
2718
2719 let max = model.model.max_token_count();
2720
2721 let index = self
2722 .messages
2723 .iter()
2724 .position(|msg| msg.id == message_id)
2725 .unwrap_or(0);
2726
2727 if index == 0 {
2728 return TotalTokenUsage { total: 0, max };
2729 }
2730
2731 let token_usage = &self
2732 .request_token_usage
2733 .get(index - 1)
2734 .cloned()
2735 .unwrap_or_default();
2736
2737 TotalTokenUsage {
2738 total: token_usage.total_tokens(),
2739 max,
2740 }
2741 }
2742
2743 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2744 let model = self.configured_model.as_ref()?;
2745
2746 let max = model.model.max_token_count();
2747
2748 if let Some(exceeded_error) = &self.exceeded_window_error {
2749 if model.model.id() == exceeded_error.model_id {
2750 return Some(TotalTokenUsage {
2751 total: exceeded_error.token_count,
2752 max,
2753 });
2754 }
2755 }
2756
2757 let total = self
2758 .token_usage_at_last_message()
2759 .unwrap_or_default()
2760 .total_tokens();
2761
2762 Some(TotalTokenUsage { total, max })
2763 }
2764
2765 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2766 self.request_token_usage
2767 .get(self.messages.len().saturating_sub(1))
2768 .or_else(|| self.request_token_usage.last())
2769 .cloned()
2770 }
2771
2772 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2773 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2774 self.request_token_usage
2775 .resize(self.messages.len(), placeholder);
2776
2777 if let Some(last) = self.request_token_usage.last_mut() {
2778 *last = token_usage;
2779 }
2780 }
2781
2782 pub fn deny_tool_use(
2783 &mut self,
2784 tool_use_id: LanguageModelToolUseId,
2785 tool_name: Arc<str>,
2786 window: Option<AnyWindowHandle>,
2787 cx: &mut Context<Self>,
2788 ) {
2789 let err = Err(anyhow::anyhow!(
2790 "Permission to run tool action denied by user"
2791 ));
2792
2793 self.tool_use.insert_tool_output(
2794 tool_use_id.clone(),
2795 tool_name,
2796 err,
2797 self.configured_model.as_ref(),
2798 );
2799 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2800 }
2801}
2802
/// User-facing errors surfaced by a thread while talking to a model.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's account requires payment before further completions.
    #[error("Payment required")]
    PaymentRequired,
    /// The request limit for the user's current plan was reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error carrying a header and detail message for display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2815
/// Events emitted by a thread as completions stream in and tools execute.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be surfaced to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// A chunk of assistant text was appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant "thinking" text was appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed (part of) a tool call.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that was not valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished (with a stop reason) or failed.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
}
2860
// `Thread` broadcasts `ThreadEvent`s to gpui subscribers (e.g. the thread UI).
impl EventEmitter<ThreadEvent> for Thread {}
2862
/// A model completion request that is currently in flight.
struct PendingCompletion {
    // Monotonic id distinguishing successive completions.
    id: usize,
    // Position/status of this completion in the provider's queue.
    // NOTE(review): exact semantics depend on `QueueState`, defined elsewhere.
    queue_state: QueueState,
    // Held only to keep the stream alive; dropping it cancels the work
    // (gpui task semantics).
    _task: Task<()>,
}
2868
2869/// Resolves tool name conflicts by ensuring all tool names are unique.
2870///
2871/// When multiple tools have the same name, this function applies the following rules:
2872/// 1. Native tools always keep their original name
2873/// 2. Context server tools get prefixed with their server ID and an underscore
2874/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
2875/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
2876///
2877/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
2878fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
2879 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
2880 let mut tool_name = tool.name();
2881 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
2882 tool_name
2883 }
2884
2885 const MAX_TOOL_NAME_LENGTH: usize = 64;
2886
2887 let mut duplicated_tool_names = HashSet::default();
2888 let mut seen_tool_names = HashSet::default();
2889 for tool in tools {
2890 let tool_name = resolve_tool_name(tool);
2891 if seen_tool_names.contains(&tool_name) {
2892 debug_assert!(
2893 tool.source() != assistant_tool::ToolSource::Native,
2894 "There are two built-in tools with the same name: {}",
2895 tool_name
2896 );
2897 duplicated_tool_names.insert(tool_name);
2898 } else {
2899 seen_tool_names.insert(tool_name);
2900 }
2901 }
2902
2903 if duplicated_tool_names.is_empty() {
2904 return tools
2905 .into_iter()
2906 .map(|tool| (resolve_tool_name(tool), tool.clone()))
2907 .collect();
2908 }
2909
2910 tools
2911 .into_iter()
2912 .filter_map(|tool| {
2913 let mut tool_name = resolve_tool_name(tool);
2914 if !duplicated_tool_names.contains(&tool_name) {
2915 return Some((tool_name, tool.clone()));
2916 }
2917 match tool.source() {
2918 assistant_tool::ToolSource::Native => {
2919 // Built-in tools always keep their original name
2920 Some((tool_name, tool.clone()))
2921 }
2922 assistant_tool::ToolSource::ContextServer { id } => {
2923 // Context server tools are prefixed with the context server ID, and truncated if necessary
2924 tool_name.insert(0, '_');
2925 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
2926 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
2927 let mut id = id.to_string();
2928 id.truncate(len);
2929 tool_name.insert_str(0, &id);
2930 } else {
2931 tool_name.insert_str(0, &id);
2932 }
2933
2934 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
2935
2936 if seen_tool_names.contains(&tool_name) {
2937 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
2938 None
2939 } else {
2940 Some((tool_name, tool.clone()))
2941 }
2942 }
2943 }
2944 })
2945 .collect()
2946}
2947
2948#[cfg(test)]
2949mod tests {
2950 use super::*;
2951 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2952 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
2953 use assistant_tool::ToolRegistry;
2954 use editor::EditorSettings;
2955 use gpui::TestAppContext;
2956 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2957 use project::{FakeFs, Project};
2958 use prompt_store::PromptBuilder;
2959 use serde_json::json;
2960 use settings::{Settings, SettingsStore};
2961 use std::sync::Arc;
2962 use theme::ThemeSettings;
2963 use ui::IconName;
2964 use util::path;
2965 use workspace::Workspace;
2966
    // Verifies that a file attached as context is rendered into the user
    // message's loaded-context text and prepended to the message content in
    // the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Message 0 is the system prompt; message 1 is the user message with
        // the context block prepended.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
3043
    // Verifies that each new user message only carries context items that no
    // earlier message already included, and that `new_context_for_thread`
    // honors an "up to message" boundary plus removals from the store.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages (plus the system prompt).
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a boundary at message 2, contexts from messages after it (and
        // newly added ones) count as "new" again.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3199
    // Verifies that messages inserted with an empty `ContextLoadResult` carry
    // no context text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3277
    // Verifies that a freshly created thread starts out with the default
    // agent profile.
    #[gpui::test]
    async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Check that we are starting with the default profile
        let profile = cx.read(|cx| thread.read(cx).profile.clone());
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());
        assert_eq!(
            profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3299
/// Round-trips a thread through `serialize`/`deserialize` and checks that
/// the (default) agent profile survives the trip.
#[gpui::test]
async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
    init_test_settings(cx);

    let project = create_test_project(
        cx,
        json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
    )
    .await;

    let (_workspace, thread_store, thread, _context_store, _model) =
        setup_test_environment(cx, project.clone()).await;

    // Profile gets serialized with default values
    let serialized = thread
        .update(cx, |thread, cx| thread.serialize(cx))
        .await
        .unwrap();

    assert_eq!(serialized.profile, Some(AgentProfileId::default()));

    // Rebuild a thread from the serialized form, reusing the original
    // thread's id, project, tools, prompt builder, and project context.
    let deserialized = cx.update(|cx| {
        thread.update(cx, |thread, cx| {
            Thread::deserialize(
                thread.id.clone(),
                serialized,
                thread.project.clone(),
                thread.tools.clone(),
                thread.prompt_builder.clone(),
                thread.project_context.clone(),
                None,
                cx,
            )
        })
    });
    let tool_set = cx.read(|cx| thread_store.read(cx).tools());

    // The deserialized thread carries the same default profile.
    assert_eq!(
        deserialized.profile,
        AgentProfile::new(AgentProfileId::default(), tool_set)
    );
}
3342
/// Checks how `model_parameters` temperature overrides in `AgentSettings`
/// are matched against the active model when building a completion request:
/// matching on provider+model, model only, or provider only applies the
/// override; a mismatched provider does not.
#[gpui::test]
async fn test_temperature_setting(cx: &mut TestAppContext) {
    init_test_settings(cx);

    let project = create_test_project(
        cx,
        json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
    )
    .await;

    let (_workspace, _thread_store, thread, _context_store, model) =
        setup_test_environment(cx, project.clone()).await;

    // Both model and provider specified, both matching the active model.
    cx.update(|cx| {
        AgentSettings::override_global(
            AgentSettings {
                model_parameters: vec![LanguageModelParameters {
                    provider: Some(model.provider_id().0.to_string().into()),
                    model: Some(model.id().0.clone()),
                    temperature: Some(0.66),
                }],
                ..AgentSettings::get_global(cx).clone()
            },
            cx,
        );
    });

    let request = thread.update(cx, |thread, cx| {
        thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
    });
    assert_eq!(request.temperature, Some(0.66));

    // Only the model id specified (provider left unset) — still applies.
    cx.update(|cx| {
        AgentSettings::override_global(
            AgentSettings {
                model_parameters: vec![LanguageModelParameters {
                    provider: None,
                    model: Some(model.id().0.clone()),
                    temperature: Some(0.66),
                }],
                ..AgentSettings::get_global(cx).clone()
            },
            cx,
        );
    });

    let request = thread.update(cx, |thread, cx| {
        thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
    });
    assert_eq!(request.temperature, Some(0.66));

    // Only the provider specified (model left unset) — still applies.
    cx.update(|cx| {
        AgentSettings::override_global(
            AgentSettings {
                model_parameters: vec![LanguageModelParameters {
                    provider: Some(model.provider_id().0.to_string().into()),
                    model: None,
                    temperature: Some(0.66),
                }],
                ..AgentSettings::get_global(cx).clone()
            },
            cx,
        );
    });

    let request = thread.update(cx, |thread, cx| {
        thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
    });
    assert_eq!(request.temperature, Some(0.66));

    // Same model id but a different provider ("anthropic", while the active
    // model comes from the fake provider) — the override must NOT apply.
    cx.update(|cx| {
        AgentSettings::override_global(
            AgentSettings {
                model_parameters: vec![LanguageModelParameters {
                    provider: Some("anthropic".into()),
                    model: Some(model.id().0.clone()),
                    temperature: Some(0.66),
                }],
                ..AgentSettings::get_global(cx).clone()
            },
            cx,
        );
    });

    let request = thread.update(cx, |thread, cx| {
        thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
    });
    assert_eq!(request.temperature, None);
}
3436
/// Exercises the `ThreadSummary` lifecycle: Pending -> Generating -> Ready,
/// including which states allow `set_summary` to take effect.
#[gpui::test]
async fn test_thread_summary(cx: &mut TestAppContext) {
    init_test_settings(cx);

    let project = create_test_project(cx, json!({})).await;

    let (_, _thread_store, thread, _context_store, model) =
        setup_test_environment(cx, project.clone()).await;

    // Initial state should be pending
    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Pending));
        assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
    });

    // Manually setting the summary should not be allowed in this state
    thread.update(cx, |thread, cx| {
        thread.set_summary("This should not work", cx);
    });

    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Pending));
    });

    // Send a message and kick off a completion.
    thread.update(cx, |thread, cx| {
        thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
        thread.send_to_model(
            model.clone(),
            CompletionIntent::ThreadSummarization,
            None,
            cx,
        );
    });

    // Complete the assistant turn so the thread has >= 2 messages.
    let fake_model = model.as_fake();
    simulate_successful_response(&fake_model, cx);

    // Should start generating summary when there are >= 2 messages
    thread.read_with(cx, |thread, _| {
        assert_eq!(*thread.summary(), ThreadSummary::Generating);
    });

    // Should not be able to set the summary while generating
    thread.update(cx, |thread, cx| {
        thread.set_summary("This should not work either", cx);
    });

    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Generating));
        assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
    });

    // Stream the summary text in two chunks, then close the stream.
    cx.run_until_parked();
    fake_model.stream_last_completion_response("Brief");
    fake_model.stream_last_completion_response(" Introduction");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Summary should be set (chunks concatenated in order).
    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
        assert_eq!(thread.summary().or_default(), "Brief Introduction");
    });

    // Now we should be able to set a summary
    thread.update(cx, |thread, cx| {
        thread.set_summary("Brief Intro", cx);
    });

    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.summary().or_default(), "Brief Intro");
    });

    // Test setting an empty summary (should default to DEFAULT)
    thread.update(cx, |thread, cx| {
        thread.set_summary("", cx);
    });

    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
        assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
    });
}
3521
3522 #[gpui::test]
3523 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3524 init_test_settings(cx);
3525
3526 let project = create_test_project(cx, json!({})).await;
3527
3528 let (_, _thread_store, thread, _context_store, model) =
3529 setup_test_environment(cx, project.clone()).await;
3530
3531 test_summarize_error(&model, &thread, cx);
3532
3533 // Now we should be able to set a summary
3534 thread.update(cx, |thread, cx| {
3535 thread.set_summary("Brief Intro", cx);
3536 });
3537
3538 thread.read_with(cx, |thread, _| {
3539 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3540 assert_eq!(thread.summary().or_default(), "Brief Intro");
3541 });
3542 }
3543
/// After a failed summarization, further user messages must not re-trigger
/// summarization automatically, but an explicit `summarize` call retries it.
#[gpui::test]
async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
    init_test_settings(cx);

    let project = create_test_project(cx, json!({})).await;

    let (_, _thread_store, thread, _context_store, model) =
        setup_test_environment(cx, project.clone()).await;

    // Put the thread summary into the Error state.
    test_summarize_error(&model, &thread, cx);

    // Sending another message should not trigger another summarize request
    thread.update(cx, |thread, cx| {
        thread.insert_user_message(
            "How are you?",
            ContextLoadResult::default(),
            None,
            vec![],
            cx,
        );
        thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
    });

    let fake_model = model.as_fake();
    simulate_successful_response(&fake_model, cx);

    thread.read_with(cx, |thread, _| {
        // State is still Error, not Generating
        assert!(matches!(thread.summary(), ThreadSummary::Error));
    });

    // But the summarize request can be invoked manually
    thread.update(cx, |thread, cx| {
        thread.summarize(cx);
    });

    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Generating));
    });

    // Stream a summary response and close the stream.
    cx.run_until_parked();
    fake_model.stream_last_completion_response("A successful summary");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Retry succeeded: summary is now Ready with the streamed text.
    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
        assert_eq!(thread.summary().or_default(), "A successful summary");
    });
}
3594
3595 #[gpui::test]
3596 fn test_resolve_tool_name_conflicts() {
3597 use assistant_tool::{Tool, ToolSource};
3598
3599 assert_resolve_tool_name_conflicts(
3600 vec![
3601 TestTool::new("tool1", ToolSource::Native),
3602 TestTool::new("tool2", ToolSource::Native),
3603 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3604 ],
3605 vec!["tool1", "tool2", "tool3"],
3606 );
3607
3608 assert_resolve_tool_name_conflicts(
3609 vec![
3610 TestTool::new("tool1", ToolSource::Native),
3611 TestTool::new("tool2", ToolSource::Native),
3612 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3613 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
3614 ],
3615 vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"],
3616 );
3617
3618 assert_resolve_tool_name_conflicts(
3619 vec![
3620 TestTool::new("tool1", ToolSource::Native),
3621 TestTool::new("tool2", ToolSource::Native),
3622 TestTool::new("tool3", ToolSource::Native),
3623 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
3624 TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
3625 ],
3626 vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"],
3627 );
3628
3629 // Test that tool with very long name is always truncated
3630 assert_resolve_tool_name_conflicts(
3631 vec![TestTool::new(
3632 "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah",
3633 ToolSource::Native,
3634 )],
3635 vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"],
3636 );
3637
3638 // Test deduplication of tools with very long names, in this case the mcp server name should be truncated
3639 assert_resolve_tool_name_conflicts(
3640 vec![
3641 TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native),
3642 TestTool::new(
3643 "tool-with-very-very-very-long-name",
3644 ToolSource::ContextServer {
3645 id: "mcp-with-very-very-very-long-name".into(),
3646 },
3647 ),
3648 ],
3649 vec![
3650 "tool-with-very-very-very-long-name",
3651 "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name",
3652 ],
3653 );
3654
3655 fn assert_resolve_tool_name_conflicts(
3656 tools: Vec<TestTool>,
3657 expected: Vec<impl Into<String>>,
3658 ) {
3659 let tools: Vec<Arc<dyn Tool>> = tools
3660 .into_iter()
3661 .map(|t| Arc::new(t) as Arc<dyn Tool>)
3662 .collect();
3663 let tools = resolve_tool_name_conflicts(&tools);
3664 assert_eq!(tools.len(), expected.len());
3665 for (i, expected_name) in expected.into_iter().enumerate() {
3666 let expected_name = expected_name.into();
3667 let actual_name = &tools[i].0;
3668 assert_eq!(
3669 actual_name, &expected_name,
3670 "Expected '{}' got '{}' at index {}",
3671 expected_name, actual_name, i
3672 );
3673 }
3674 }
3675
3676 struct TestTool {
3677 name: String,
3678 source: ToolSource,
3679 }
3680
3681 impl TestTool {
3682 fn new(name: impl Into<String>, source: ToolSource) -> Self {
3683 Self {
3684 name: name.into(),
3685 source,
3686 }
3687 }
3688 }
3689
3690 impl Tool for TestTool {
3691 fn name(&self) -> String {
3692 self.name.clone()
3693 }
3694
3695 fn icon(&self) -> IconName {
3696 IconName::Ai
3697 }
3698
3699 fn may_perform_edits(&self) -> bool {
3700 false
3701 }
3702
3703 fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
3704 true
3705 }
3706
3707 fn source(&self) -> ToolSource {
3708 self.source.clone()
3709 }
3710
3711 fn description(&self) -> String {
3712 "Test tool".to_string()
3713 }
3714
3715 fn ui_text(&self, _input: &serde_json::Value) -> String {
3716 "Test tool".to_string()
3717 }
3718
3719 fn run(
3720 self: Arc<Self>,
3721 _input: serde_json::Value,
3722 _request: Arc<LanguageModelRequest>,
3723 _project: Entity<Project>,
3724 _action_log: Entity<ActionLog>,
3725 _model: Arc<dyn LanguageModel>,
3726 _window: Option<AnyWindowHandle>,
3727 _cx: &mut App,
3728 ) -> assistant_tool::ToolResult {
3729 assistant_tool::ToolResult {
3730 output: Task::ready(Err(anyhow::anyhow!("No content"))),
3731 card: None,
3732 }
3733 }
3734 }
3735 }
3736
/// Drives `thread` into the `ThreadSummary::Error` state: sends a user
/// message, lets the main completion succeed, then ends the summary
/// completion stream without streaming any text.
fn test_summarize_error(
    model: &Arc<dyn LanguageModel>,
    thread: &Entity<Thread>,
    cx: &mut TestAppContext,
) {
    thread.update(cx, |thread, cx| {
        thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
        thread.send_to_model(
            model.clone(),
            CompletionIntent::ThreadSummarization,
            None,
            cx,
        );
    });

    // Complete the assistant turn so summarization kicks off.
    let fake_model = model.as_fake();
    simulate_successful_response(&fake_model, cx);

    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Generating));
        assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
    });

    // Simulate summary request ending without any streamed content
    cx.run_until_parked();
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // State is set to Error and default message
    thread.read_with(cx, |thread, _| {
        assert!(matches!(thread.summary(), ThreadSummary::Error));
        assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
    });
}
3771
/// Streams a canned "Assistant response" through the fake model's pending
/// completion and closes the stream, parking the executor before and after
/// so the thread processes the events.
fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
    cx.run_until_parked();
    fake_model.stream_last_completion_response("Assistant response");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
}
3778
/// Installs the global test `SettingsStore` and registers every settings
/// type these tests touch (language, project, agent, prompt store, thread
/// store, workspace, language models, theme, editor, tool registry).
fn init_test_settings(cx: &mut TestAppContext) {
    cx.update(|cx| {
        // Install the settings store first; the init/register calls below
        // record their settings globally.
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        language::init(cx);
        Project::init_settings(cx);
        AgentSettings::register(cx);
        prompt_store::init(cx);
        thread_store::init(cx);
        workspace::init_settings(cx);
        language_model::init_settings(cx);
        ThemeSettings::register(cx);
        EditorSettings::register(cx);
        ToolRegistry::default_global(cx);
    });
}
3795
3796 // Helper to create a test project with test files
3797 async fn create_test_project(
3798 cx: &mut TestAppContext,
3799 files: serde_json::Value,
3800 ) -> Entity<Project> {
3801 let fs = FakeFs::new(cx.executor());
3802 fs.insert_tree(path!("/test"), files).await;
3803 Project::test(fs, [path!("/test").as_ref()], cx).await
3804 }
3805
/// Creates the shared fixture for these tests: a workspace and thread store
/// over `project`, a fresh thread, an empty context store, and a fake
/// language model registered as both the default and the thread-summary
/// model.
async fn setup_test_environment(
    cx: &mut TestAppContext,
    project: Entity<Project>,
) -> (
    Entity<Workspace>,
    Entity<ThreadStore>,
    Entity<Thread>,
    Entity<ContextStore>,
    Arc<dyn LanguageModel>,
) {
    let (workspace, cx) =
        cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

    let thread_store = cx
        .update(|_, cx| {
            ThreadStore::load(
                project.clone(),
                cx.new(|_| ToolWorkingSet::default()),
                None,
                Arc::new(PromptBuilder::new(None).unwrap()),
                cx,
            )
        })
        .await
        .unwrap();

    let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
    let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

    let provider = Arc::new(FakeLanguageModelProvider);
    let model = provider.test_model();
    let model: Arc<dyn LanguageModel> = Arc::new(model);

    // Register the fake model so both normal completions and summary
    // generation route through it.
    cx.update(|_, cx| {
        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            registry.set_default_model(
                Some(ConfiguredModel {
                    provider: provider.clone(),
                    model: model.clone(),
                }),
                cx,
            );
            registry.set_thread_summary_model(
                Some(ConfiguredModel {
                    provider,
                    model: model.clone(),
                }),
                cx,
            );
        })
    });

    (workspace, thread_store, thread, context_store, model)
}
3860
3861 async fn add_file_to_context(
3862 project: &Entity<Project>,
3863 context_store: &Entity<ContextStore>,
3864 path: &str,
3865 cx: &mut TestAppContext,
3866 ) -> Result<Entity<language::Buffer>> {
3867 let buffer_path = project
3868 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3869 .unwrap();
3870
3871 let buffer = project
3872 .update(cx, |project, cx| {
3873 project.open_buffer(buffer_path.clone(), cx)
3874 })
3875 .await
3876 .unwrap();
3877
3878 context_store.update(cx, |context_store, cx| {
3879 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3880 });
3881
3882 Ok(buffer)
3883 }
3884}