1use std::io::Write;
2use std::ops::Range;
3use std::sync::Arc;
4use std::time::Instant;
5
6use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
7use anyhow::{Result, anyhow};
8use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
9use chrono::{DateTime, Utc};
10use client::{ModelRequestUsage, RequestUsage};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
27 TokenUsage,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40
41use uuid::Uuid;
42use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
43
44use crate::ThreadStore;
45use crate::agent_profile::AgentProfile;
46use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
47use crate::thread_store::{
48 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
49 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
50};
51use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
52
/// Unique identifier for a [`Thread`], stored as a UUID string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
    /// Generates a fresh random (v4 UUID) identifier.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for ThreadId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<&str> for ThreadId {
    fn from(value: &str) -> Self {
        Self(value.into())
    }
}
75
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    /// Generates a fresh random (v4 UUID) identifier.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
93
/// Identifier of a message within a thread; allocated sequentially by
/// [`Thread`] via `post_inc`, so later messages always compare greater.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
102
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Byte range of the crease within the message text.
    pub range: Range<usize>,
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces (text / thinking / redacted thinking).
    pub segments: Vec<MessageSegment>,
    /// Context that was loaded for this message (only text survives deserialization).
    pub loaded_context: LoadedContext,
    /// Creases to restore when re-opening this message in an editor.
    pub creases: Vec<MessageCrease>,
    /// Hidden messages (e.g. the automatic "continue" nudge) are excluded from
    /// turn-end detection — presumably also from rendering; confirm in UI code.
    pub is_hidden: bool,
}
122
123impl Message {
124 /// Returns whether the message contains any meaningful text that should be displayed
125 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
126 pub fn should_display_content(&self) -> bool {
127 self.segments.iter().all(|segment| segment.should_display())
128 }
129
130 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
131 if let Some(MessageSegment::Thinking {
132 text: segment,
133 signature: current_signature,
134 }) = self.segments.last_mut()
135 {
136 if let Some(signature) = signature {
137 *current_signature = Some(signature);
138 }
139 segment.push_str(text);
140 } else {
141 self.segments.push(MessageSegment::Thinking {
142 text: text.to_string(),
143 signature,
144 });
145 }
146 }
147
148 pub fn push_text(&mut self, text: &str) {
149 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
150 segment.push_str(text);
151 } else {
152 self.segments.push(MessageSegment::Text(text.to_string()));
153 }
154 }
155
156 pub fn to_string(&self) -> String {
157 let mut result = String::new();
158
159 if !self.loaded_context.text.is_empty() {
160 result.push_str(&self.loaded_context.text);
161 }
162
163 for segment in &self.segments {
164 match segment {
165 MessageSegment::Text(text) => result.push_str(text),
166 MessageSegment::Thinking { text, .. } => {
167 result.push_str("<think>\n");
168 result.push_str(text);
169 result.push_str("\n</think>");
170 }
171 MessageSegment::RedactedThinking(_) => {}
172 }
173 }
174
175 result
176 }
177}
178
/// One piece of a [`Message`]: plain text, visible "thinking" output
/// (optionally carrying a provider signature), or redacted thinking data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Whether this segment carries displayable text.
    ///
    /// Text and thinking segments are displayable only when non-empty; the
    /// previous `text.is_empty()` check had the polarity inverted, reporting
    /// empty segments as displayable. Redacted thinking is never displayable.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
198
/// A capture of the project's state at a point in time (serialized with the
/// thread and used for telemetry/inspection).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree plus its git state, when it is a repository.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    pub git_state: Option<GitState>,
}

/// Git repository state captured in a [`WorktreeSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
219
/// Associates a message with the git checkpoint captured when it was sent,
/// enabling "restore to this point" behavior.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Thumbs-up / thumbs-down feedback a user can leave on a thread or message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
231
232pub enum LastRestoreCheckpoint {
233 Pending {
234 message_id: MessageId,
235 },
236 Error {
237 message_id: MessageId,
238 error: String,
239 },
240}
241
242impl LastRestoreCheckpoint {
243 pub fn message_id(&self) -> MessageId {
244 match self {
245 LastRestoreCheckpoint::Pending { message_id } => *message_id,
246 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
247 }
248 }
249}
250
251#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
252pub enum DetailedSummaryState {
253 #[default]
254 NotGenerated,
255 Generating {
256 message_id: MessageId,
257 },
258 Generated {
259 text: SharedString,
260 message_id: MessageId,
261 },
262}
263
264impl DetailedSummaryState {
265 fn text(&self) -> Option<SharedString> {
266 if let Self::Generated { text, .. } = self {
267 Some(text.clone())
268 } else {
269 None
270 }
271 }
272}
273
/// Running token total for a thread together with the model's context-window
/// maximum (`max == 0` means no model is selected / limit unknown).
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    pub total: u64,
    pub max: u64,
}

impl TotalTokenUsage {
    /// Classifies current usage relative to the context window.
    ///
    /// In debug builds the warning threshold may be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable; a missing or
    /// unparsable value falls back to the default instead of panicking
    /// (the previous `.parse().unwrap()` would crash on malformed input).
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a new usage with `tokens` added to the total; `max` is preserved.
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close token usage is to the context-window limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
318
/// Where a pending completion currently is in the request lifecycle.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
325
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    updated_at: DateTime<Utc>,
    // Short title of the thread (see `ThreadSummary` for its lifecycle).
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Kept in ascending `MessageId` order; `message()` binary-searches on it.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints recorded per message, powering checkpoint restore.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint captured at the start of the current turn; finalized or
    // discarded once generation settles.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history — presumably one entry per completion
    // request; confirm against the code that pushes to it.
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // Updated on every streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Turns the agent may still take; initialized to `u32::MAX` (unlimited).
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
    profile: AgentProfile,
}
366
367#[derive(Clone, Debug, PartialEq, Eq)]
368pub enum ThreadSummary {
369 Pending,
370 Generating,
371 Ready(SharedString),
372 Error,
373}
374
375impl ThreadSummary {
376 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
377
378 pub fn or_default(&self) -> SharedString {
379 self.unwrap_or(Self::DEFAULT)
380 }
381
382 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
383 self.ready().unwrap_or_else(|| message.into())
384 }
385
386 pub fn ready(&self) -> Option<SharedString> {
387 match self {
388 ThreadSummary::Ready(summary) => Some(summary.clone()),
389 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
390 }
391 }
392}
393
/// Details recorded when a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
401
402impl Thread {
    /// Creates an empty thread, pulling defaults (model, profile, completion
    /// mode) from global settings and the language-model registry.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
        let profile_id = AgentSettings::get_global(cx).default_profile.clone();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Capture the project's initial state in the background so it
                // can be attached when the thread is serialized.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
457
    /// Rebuilds a thread from its serialized form.
    ///
    /// Restored messages keep their context only as text (`AgentContextHandle`s
    /// and crease contexts are not revived). The model saved with the thread is
    /// preferred; when unavailable, the registry's default model is used.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue id allocation after the highest restored message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Re-select the saved model if the registry still offers it; otherwise
        // fall back to the default model.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
        let profile_id = serialized
            .profile
            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the rendered context text survives serialization.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
            profile: AgentProfile::new(profile_id, tools),
        }
    }
581
    /// Registers a callback invoked with each completion request and the
    /// events streamed back for it (useful for tests and tracing).
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    /// The thread's stable unique identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// The agent profile (tool configuration) currently in effect.
    pub fn profile(&self) -> &AgentProfile {
        &self.profile
    }

    /// Switches profiles, emitting `ProfileChanged` only on an actual change.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        if &id != self.profile.id() {
            self.profile = AgentProfile::new(id, self.tools.clone());
            cx.emit(ThreadEvent::ProfileChanged);
        }
    }

    /// Whether the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Begins a new prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// The shared project context used when building the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, falling back to (and caching) the global
    /// default model when none has been chosen yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    /// The model currently configured for this thread, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Sets (or clears) the model used for completions.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// The current summary (title) state of the thread.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
644
645 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
646 let current_summary = match &self.summary {
647 ThreadSummary::Pending | ThreadSummary::Generating => return,
648 ThreadSummary::Ready(summary) => summary,
649 ThreadSummary::Error => &ThreadSummary::DEFAULT,
650 };
651
652 let mut new_summary = new_summary.into();
653
654 if new_summary.is_empty() {
655 new_summary = ThreadSummary::DEFAULT;
656 }
657
658 if current_summary != &new_summary {
659 self.summary = ThreadSummary::Ready(new_summary);
660 cx.emit(ThreadEvent::SummaryChanged);
661 }
662 }
663
    /// The completion mode used for requests.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Overrides the completion mode for subsequent requests.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
671
672 pub fn message(&self, id: MessageId) -> Option<&Message> {
673 let index = self
674 .messages
675 .binary_search_by(|message| message.id.cmp(&id))
676 .ok()?;
677
678 self.messages.get(index)
679 }
680
    /// Iterates over all messages in order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }

    /// Whether a completion is in flight or tool uses are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale, i.e. no
    /// chunk has arrived within the threshold. Returns `None` when no chunk
    /// has been received yet; note this does not itself consult
    /// `is_generating()` — callers should check that first.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before the stream counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streaming chunk just arrived.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue position of the oldest pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }

    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Finds a pending tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that are waiting for user confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// Whether any tool uses are still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The checkpoint recorded for a message, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
733
    /// Restores the project to `checkpoint`'s git state, truncating the thread
    /// back to the checkpoint's message on success. Progress and failures are
    /// surfaced through `last_restore_checkpoint` and
    /// `ThreadEvent::CheckpointChanged`.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Mark the restore as in-flight so the UI can reflect it immediately.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failed state so it can be reported to the user.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
769
770 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
771 let pending_checkpoint = if self.is_generating() {
772 return;
773 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
774 checkpoint
775 } else {
776 return;
777 };
778
779 self.finalize_checkpoint(pending_checkpoint, cx);
780 }
781
    /// Compares the pending checkpoint against the repository's current state
    /// and records it only when something actually changed. When checkpointing
    /// or comparison fails, the checkpoint is kept to be safe.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison is treated as "changed" (unwrap_or(false)).
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                // Only keep checkpoints that represent a real change.
                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
816
    /// Records a checkpoint under its message id and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint-restore attempt, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
827
828 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
829 let Some(message_ix) = self
830 .messages
831 .iter()
832 .rposition(|message| message.id == message_id)
833 else {
834 return;
835 };
836 for deleted_message in self.messages.drain(message_ix..) {
837 self.checkpoints_by_message.remove(&deleted_message.id);
838 }
839 cx.notify();
840 }
841
842 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
843 self.messages
844 .iter()
845 .find(|message| message.id == id)
846 .into_iter()
847 .flat_map(|message| message.loaded_context.contexts.iter())
848 }
849
850 pub fn is_turn_end(&self, ix: usize) -> bool {
851 if self.messages.is_empty() {
852 return false;
853 }
854
855 if !self.is_generating() && ix == self.messages.len() - 1 {
856 return true;
857 }
858
859 let Some(message) = self.messages.get(ix) else {
860 return false;
861 };
862
863 if message.role != Role::Assistant {
864 return false;
865 }
866
867 self.messages
868 .get(ix + 1)
869 .and_then(|message| {
870 self.message(message.id)
871 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
872 })
873 .unwrap_or(false)
874 }
875
    /// Whether the provider reported that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|pending_tool_use| pending_tool_use.status.is_error())
    }

    /// Returns whether any pending tool uses may perform edits
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        self.tool_use
            .pending_tool_uses()
            .iter()
            .filter(|pending_tool_use| !pending_tool_use.status.is_error())
            .any(|pending_tool_use| pending_tool_use.may_perform_edits)
    }
898
    /// Tool uses issued by the assistant message with the given id.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced for the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result of a specific tool use, if one has been recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The textual output of a tool use; image results currently yield `None`.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text) => Some(text),
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// The UI card associated with a tool use, if one was created.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
927
928 /// Return tools that are both enabled and supported by the model
929 pub fn available_tools(
930 &self,
931 cx: &App,
932 model: Arc<dyn LanguageModel>,
933 ) -> Vec<LanguageModelRequestTool> {
934 if model.supports_tools() {
935 self.profile
936 .enabled_tools(cx)
937 .into_iter()
938 .filter_map(|tool| {
939 // Skip tools that cannot be supported
940 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
941 Some(LanguageModelRequestTool {
942 name: tool.name(),
943 description: tool.description(),
944 input_schema,
945 })
946 })
947 .collect()
948 } else {
949 Vec::default()
950 }
951 }
952
953 pub fn insert_user_message(
954 &mut self,
955 text: impl Into<String>,
956 loaded_context: ContextLoadResult,
957 git_checkpoint: Option<GitStoreCheckpoint>,
958 creases: Vec<MessageCrease>,
959 cx: &mut Context<Self>,
960 ) -> MessageId {
961 if !loaded_context.referenced_buffers.is_empty() {
962 self.action_log.update(cx, |log, cx| {
963 for buffer in loaded_context.referenced_buffers {
964 log.buffer_read(buffer, cx);
965 }
966 });
967 }
968
969 let message_id = self.insert_message(
970 Role::User,
971 vec![MessageSegment::Text(text.into())],
972 loaded_context.loaded_context,
973 creases,
974 false,
975 cx,
976 );
977
978 if let Some(git_checkpoint) = git_checkpoint {
979 self.pending_checkpoint = Some(ThreadCheckpoint {
980 message_id,
981 git_checkpoint,
982 });
983 }
984
985 self.auto_capture_telemetry(cx);
986
987 message_id
988 }
989
990 pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
991 let id = self.insert_message(
992 Role::User,
993 vec![MessageSegment::Text("Continue where you left off".into())],
994 LoadedContext::default(),
995 vec![],
996 true,
997 cx,
998 );
999 self.pending_checkpoint = None;
1000
1001 id
1002 }
1003
1004 pub fn insert_assistant_message(
1005 &mut self,
1006 segments: Vec<MessageSegment>,
1007 cx: &mut Context<Self>,
1008 ) -> MessageId {
1009 self.insert_message(
1010 Role::Assistant,
1011 segments,
1012 LoadedContext::default(),
1013 Vec::new(),
1014 false,
1015 cx,
1016 )
1017 }
1018
1019 pub fn insert_message(
1020 &mut self,
1021 role: Role,
1022 segments: Vec<MessageSegment>,
1023 loaded_context: LoadedContext,
1024 creases: Vec<MessageCrease>,
1025 is_hidden: bool,
1026 cx: &mut Context<Self>,
1027 ) -> MessageId {
1028 let id = self.next_message_id.post_inc();
1029 self.messages.push(Message {
1030 id,
1031 role,
1032 segments,
1033 loaded_context,
1034 creases,
1035 is_hidden,
1036 });
1037 self.touch_updated_at();
1038 cx.emit(ThreadEvent::MessageAdded(id));
1039 id
1040 }
1041
1042 pub fn edit_message(
1043 &mut self,
1044 id: MessageId,
1045 new_role: Role,
1046 new_segments: Vec<MessageSegment>,
1047 creases: Vec<MessageCrease>,
1048 loaded_context: Option<LoadedContext>,
1049 checkpoint: Option<GitStoreCheckpoint>,
1050 cx: &mut Context<Self>,
1051 ) -> bool {
1052 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1053 return false;
1054 };
1055 message.role = new_role;
1056 message.segments = new_segments;
1057 message.creases = creases;
1058 if let Some(context) = loaded_context {
1059 message.loaded_context = context;
1060 }
1061 if let Some(git_checkpoint) = checkpoint {
1062 self.checkpoints_by_message.insert(
1063 id,
1064 ThreadCheckpoint {
1065 message_id: id,
1066 git_checkpoint,
1067 },
1068 );
1069 }
1070 self.touch_updated_at();
1071 cx.emit(ThreadEvent::MessageEdited(id));
1072 true
1073 }
1074
1075 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1076 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1077 return false;
1078 };
1079 self.messages.remove(index);
1080 self.touch_updated_at();
1081 cx.emit(ThreadEvent::MessageDeleted(id));
1082 true
1083 }
1084
1085 /// Returns the representation of this [`Thread`] in a textual form.
1086 ///
1087 /// This is the representation we use when attaching a thread as context to another thread.
1088 pub fn text(&self) -> String {
1089 let mut text = String::new();
1090
1091 for message in &self.messages {
1092 text.push_str(match message.role {
1093 language_model::Role::User => "User:",
1094 language_model::Role::Assistant => "Agent:",
1095 language_model::Role::System => "System:",
1096 });
1097 text.push('\n');
1098
1099 for segment in &message.segments {
1100 match segment {
1101 MessageSegment::Text(content) => text.push_str(content),
1102 MessageSegment::Thinking { text: content, .. } => {
1103 text.push_str(&format!("<think>{}</think>", content))
1104 }
1105 MessageSegment::RedactedThinking(_) => {}
1106 }
1107 }
1108 text.push('\n');
1109 }
1110
1111 text
1112 }
1113
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Returns a task because the initial project snapshot is a shared future
    /// that may not have resolved yet; the serialized form is produced only
    /// once that snapshot is available.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        // Clone the shared future now so the spawned task doesn't borrow `self`.
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                // Each in-memory message is flattened together with its tool
                // uses/results and crease metadata into one serialized record.
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the rendered text of the loaded context is persisted.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                // Snapshot of the watch channel's current detailed-summary state.
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
                profile: Some(this.profile.id().clone()),
            })
        })
    }
1199
    /// Returns how many agentic turns this thread may still take before
    /// `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1203
    /// Sets the remaining-turn budget consumed by `send_to_model`.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1207
1208 pub fn send_to_model(
1209 &mut self,
1210 model: Arc<dyn LanguageModel>,
1211 intent: CompletionIntent,
1212 window: Option<AnyWindowHandle>,
1213 cx: &mut Context<Self>,
1214 ) {
1215 if self.remaining_turns == 0 {
1216 return;
1217 }
1218
1219 self.remaining_turns -= 1;
1220
1221 let request = self.to_completion_request(model.clone(), intent, cx);
1222
1223 self.stream_completion(request, model, window, cx);
1224 }
1225
1226 pub fn used_tools_since_last_user_message(&self) -> bool {
1227 for message in self.messages.iter().rev() {
1228 if self.tool_use.message_has_tool_results(message.id) {
1229 return true;
1230 } else if message.role == Role::User {
1231 return false;
1232 }
1233 }
1234
1235 false
1236 }
1237
    /// Builds the full [`LanguageModelRequest`] for this thread: system prompt,
    /// every message with its tool uses/results, the available tools, and the
    /// completion mode.
    ///
    /// Tool results are emitted as a synthetic user message immediately after
    /// the assistant message that requested them. The last fully-resolved
    /// message is marked for prompt caching.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        // The system prompt template is told which tools exist so it can
        // describe them to the model.
        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // The system prompt needs the project context; if it isn't ready yet
        // (or prompt generation fails) we surface an error and proceed without
        // a system message.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    // The system prompt is stable across requests, so it's
                    // always marked cacheable.
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index (into `request.messages`) of the newest message eligible for
        // prompt caching; updated as messages are appended below.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, threads, etc.) precedes the message's
            // own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Tool results travel in a separate user-role message that follows
            // the assistant message containing the corresponding tool uses.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A pending (unresolved) tool use makes this message
                    // unstable, so it must not be cached.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        request.tools = available_tools;
        // Burn/max mode is only requested from models that support it.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1395
1396 fn to_summarize_request(
1397 &self,
1398 model: &Arc<dyn LanguageModel>,
1399 intent: CompletionIntent,
1400 added_user_message: String,
1401 cx: &App,
1402 ) -> LanguageModelRequest {
1403 let mut request = LanguageModelRequest {
1404 thread_id: None,
1405 prompt_id: None,
1406 intent: Some(intent),
1407 mode: None,
1408 messages: vec![],
1409 tools: Vec::new(),
1410 tool_choice: None,
1411 stop: Vec::new(),
1412 temperature: AgentSettings::temperature_for_model(model, cx),
1413 };
1414
1415 for message in &self.messages {
1416 let mut request_message = LanguageModelRequestMessage {
1417 role: message.role,
1418 content: Vec::new(),
1419 cache: false,
1420 };
1421
1422 for segment in &message.segments {
1423 match segment {
1424 MessageSegment::Text(text) => request_message
1425 .content
1426 .push(MessageContent::Text(text.clone())),
1427 MessageSegment::Thinking { .. } => {}
1428 MessageSegment::RedactedThinking(_) => {}
1429 }
1430 }
1431
1432 if request_message.content.is_empty() {
1433 continue;
1434 }
1435
1436 request.messages.push(request_message);
1437 }
1438
1439 request.messages.push(LanguageModelRequestMessage {
1440 role: Role::User,
1441 content: vec![MessageContent::Text(added_user_message)],
1442 cache: false,
1443 });
1444
1445 request
1446 }
1447
1448 pub fn stream_completion(
1449 &mut self,
1450 request: LanguageModelRequest,
1451 model: Arc<dyn LanguageModel>,
1452 window: Option<AnyWindowHandle>,
1453 cx: &mut Context<Self>,
1454 ) {
1455 self.tool_use_limit_reached = false;
1456
1457 let pending_completion_id = post_inc(&mut self.completion_count);
1458 let mut request_callback_parameters = if self.request_callback.is_some() {
1459 Some((request.clone(), Vec::new()))
1460 } else {
1461 None
1462 };
1463 let prompt_id = self.last_prompt_id.clone();
1464 let tool_use_metadata = ToolUseMetadata {
1465 model: model.clone(),
1466 thread_id: self.id.clone(),
1467 prompt_id: prompt_id.clone(),
1468 };
1469
1470 self.last_received_chunk_at = Some(Instant::now());
1471
1472 let task = cx.spawn(async move |thread, cx| {
1473 let stream_completion_future = model.stream_completion(request, &cx);
1474 let initial_token_usage =
1475 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1476 let stream_completion = async {
1477 let mut events = stream_completion_future.await?;
1478
1479 let mut stop_reason = StopReason::EndTurn;
1480 let mut current_token_usage = TokenUsage::default();
1481
1482 thread
1483 .update(cx, |_thread, cx| {
1484 cx.emit(ThreadEvent::NewRequest);
1485 })
1486 .ok();
1487
1488 let mut request_assistant_message_id = None;
1489
1490 while let Some(event) = events.next().await {
1491 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1492 response_events
1493 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1494 }
1495
1496 thread.update(cx, |thread, cx| {
1497 let event = match event {
1498 Ok(event) => event,
1499 Err(LanguageModelCompletionError::BadInputJson {
1500 id,
1501 tool_name,
1502 raw_input: invalid_input_json,
1503 json_parse_error,
1504 }) => {
1505 thread.receive_invalid_tool_json(
1506 id,
1507 tool_name,
1508 invalid_input_json,
1509 json_parse_error,
1510 window,
1511 cx,
1512 );
1513 return Ok(());
1514 }
1515 Err(LanguageModelCompletionError::Other(error)) => {
1516 return Err(error);
1517 }
1518 Err(err @ LanguageModelCompletionError::RateLimit(..)) => {
1519 return Err(err.into());
1520 }
1521 };
1522
1523 match event {
1524 LanguageModelCompletionEvent::StartMessage { .. } => {
1525 request_assistant_message_id =
1526 Some(thread.insert_assistant_message(
1527 vec![MessageSegment::Text(String::new())],
1528 cx,
1529 ));
1530 }
1531 LanguageModelCompletionEvent::Stop(reason) => {
1532 stop_reason = reason;
1533 }
1534 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1535 thread.update_token_usage_at_last_message(token_usage);
1536 thread.cumulative_token_usage = thread.cumulative_token_usage
1537 + token_usage
1538 - current_token_usage;
1539 current_token_usage = token_usage;
1540 }
1541 LanguageModelCompletionEvent::Text(chunk) => {
1542 thread.received_chunk();
1543
1544 cx.emit(ThreadEvent::ReceivedTextChunk);
1545 if let Some(last_message) = thread.messages.last_mut() {
1546 if last_message.role == Role::Assistant
1547 && !thread.tool_use.has_tool_results(last_message.id)
1548 {
1549 last_message.push_text(&chunk);
1550 cx.emit(ThreadEvent::StreamedAssistantText(
1551 last_message.id,
1552 chunk,
1553 ));
1554 } else {
1555 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1556 // of a new Assistant response.
1557 //
1558 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1559 // will result in duplicating the text of the chunk in the rendered Markdown.
1560 request_assistant_message_id =
1561 Some(thread.insert_assistant_message(
1562 vec![MessageSegment::Text(chunk.to_string())],
1563 cx,
1564 ));
1565 };
1566 }
1567 }
1568 LanguageModelCompletionEvent::Thinking {
1569 text: chunk,
1570 signature,
1571 } => {
1572 thread.received_chunk();
1573
1574 if let Some(last_message) = thread.messages.last_mut() {
1575 if last_message.role == Role::Assistant
1576 && !thread.tool_use.has_tool_results(last_message.id)
1577 {
1578 last_message.push_thinking(&chunk, signature);
1579 cx.emit(ThreadEvent::StreamedAssistantThinking(
1580 last_message.id,
1581 chunk,
1582 ));
1583 } else {
1584 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1585 // of a new Assistant response.
1586 //
1587 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1588 // will result in duplicating the text of the chunk in the rendered Markdown.
1589 request_assistant_message_id =
1590 Some(thread.insert_assistant_message(
1591 vec![MessageSegment::Thinking {
1592 text: chunk.to_string(),
1593 signature,
1594 }],
1595 cx,
1596 ));
1597 };
1598 }
1599 }
1600 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1601 let last_assistant_message_id = request_assistant_message_id
1602 .unwrap_or_else(|| {
1603 let new_assistant_message_id =
1604 thread.insert_assistant_message(vec![], cx);
1605 request_assistant_message_id =
1606 Some(new_assistant_message_id);
1607 new_assistant_message_id
1608 });
1609
1610 let tool_use_id = tool_use.id.clone();
1611 let streamed_input = if tool_use.is_input_complete {
1612 None
1613 } else {
1614 Some((&tool_use.input).clone())
1615 };
1616
1617 let ui_text = thread.tool_use.request_tool_use(
1618 last_assistant_message_id,
1619 tool_use,
1620 tool_use_metadata.clone(),
1621 cx,
1622 );
1623
1624 if let Some(input) = streamed_input {
1625 cx.emit(ThreadEvent::StreamedToolUse {
1626 tool_use_id,
1627 ui_text,
1628 input,
1629 });
1630 }
1631 }
1632 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1633 if let Some(completion) = thread
1634 .pending_completions
1635 .iter_mut()
1636 .find(|completion| completion.id == pending_completion_id)
1637 {
1638 match status_update {
1639 CompletionRequestStatus::Queued {
1640 position,
1641 } => {
1642 completion.queue_state = QueueState::Queued { position };
1643 }
1644 CompletionRequestStatus::Started => {
1645 completion.queue_state = QueueState::Started;
1646 }
1647 CompletionRequestStatus::Failed {
1648 code, message, request_id
1649 } => {
1650 anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
1651 }
1652 CompletionRequestStatus::UsageUpdated {
1653 amount, limit
1654 } => {
1655 thread.update_model_request_usage(amount as u32, limit, cx);
1656 }
1657 CompletionRequestStatus::ToolUseLimitReached => {
1658 thread.tool_use_limit_reached = true;
1659 cx.emit(ThreadEvent::ToolUseLimitReached);
1660 }
1661 }
1662 }
1663 }
1664 }
1665
1666 thread.touch_updated_at();
1667 cx.emit(ThreadEvent::StreamedCompletion);
1668 cx.notify();
1669
1670 thread.auto_capture_telemetry(cx);
1671 Ok(())
1672 })??;
1673
1674 smol::future::yield_now().await;
1675 }
1676
1677 thread.update(cx, |thread, cx| {
1678 thread.last_received_chunk_at = None;
1679 thread
1680 .pending_completions
1681 .retain(|completion| completion.id != pending_completion_id);
1682
1683 // If there is a response without tool use, summarize the message. Otherwise,
1684 // allow two tool uses before summarizing.
1685 if matches!(thread.summary, ThreadSummary::Pending)
1686 && thread.messages.len() >= 2
1687 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1688 {
1689 thread.summarize(cx);
1690 }
1691 })?;
1692
1693 anyhow::Ok(stop_reason)
1694 };
1695
1696 let result = stream_completion.await;
1697
1698 thread
1699 .update(cx, |thread, cx| {
1700 thread.finalize_pending_checkpoint(cx);
1701 match result.as_ref() {
1702 Ok(stop_reason) => match stop_reason {
1703 StopReason::ToolUse => {
1704 let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1705 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1706 }
1707 StopReason::EndTurn | StopReason::MaxTokens => {
1708 thread.project.update(cx, |project, cx| {
1709 project.set_agent_location(None, cx);
1710 });
1711 }
1712 StopReason::Refusal => {
1713 thread.project.update(cx, |project, cx| {
1714 project.set_agent_location(None, cx);
1715 });
1716
1717 // Remove the turn that was refused.
1718 //
1719 // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1720 {
1721 let mut messages_to_remove = Vec::new();
1722
1723 for (ix, message) in thread.messages.iter().enumerate().rev() {
1724 messages_to_remove.push(message.id);
1725
1726 if message.role == Role::User {
1727 if ix == 0 {
1728 break;
1729 }
1730
1731 if let Some(prev_message) = thread.messages.get(ix - 1) {
1732 if prev_message.role == Role::Assistant {
1733 break;
1734 }
1735 }
1736 }
1737 }
1738
1739 for message_id in messages_to_remove {
1740 thread.delete_message(message_id, cx);
1741 }
1742 }
1743
1744 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1745 header: "Language model refusal".into(),
1746 message: "Model refused to generate content for safety reasons.".into(),
1747 }));
1748 }
1749 },
1750 Err(error) => {
1751 thread.project.update(cx, |project, cx| {
1752 project.set_agent_location(None, cx);
1753 });
1754
1755 if error.is::<PaymentRequiredError>() {
1756 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1757 } else if let Some(error) =
1758 error.downcast_ref::<ModelRequestLimitReachedError>()
1759 {
1760 cx.emit(ThreadEvent::ShowError(
1761 ThreadError::ModelRequestLimitReached { plan: error.plan },
1762 ));
1763 } else if let Some(known_error) =
1764 error.downcast_ref::<LanguageModelKnownError>()
1765 {
1766 match known_error {
1767 LanguageModelKnownError::ContextWindowLimitExceeded {
1768 tokens,
1769 } => {
1770 thread.exceeded_window_error = Some(ExceededWindowError {
1771 model_id: model.id(),
1772 token_count: *tokens,
1773 });
1774 cx.notify();
1775 }
1776 }
1777 } else {
1778 let error_message = error
1779 .chain()
1780 .map(|err| err.to_string())
1781 .collect::<Vec<_>>()
1782 .join("\n");
1783 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1784 header: "Error interacting with language model".into(),
1785 message: SharedString::from(error_message.clone()),
1786 }));
1787 }
1788
1789 thread.cancel_last_completion(window, cx);
1790 }
1791 }
1792
1793 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1794
1795 if let Some((request_callback, (request, response_events))) = thread
1796 .request_callback
1797 .as_mut()
1798 .zip(request_callback_parameters.as_ref())
1799 {
1800 request_callback(request, response_events);
1801 }
1802
1803 thread.auto_capture_telemetry(cx);
1804
1805 if let Ok(initial_usage) = initial_token_usage {
1806 let usage = thread.cumulative_token_usage - initial_usage;
1807
1808 telemetry::event!(
1809 "Assistant Thread Completion",
1810 thread_id = thread.id().to_string(),
1811 prompt_id = prompt_id,
1812 model = model.telemetry_id(),
1813 model_provider = model.provider_id().to_string(),
1814 input_tokens = usage.input_tokens,
1815 output_tokens = usage.output_tokens,
1816 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1817 cache_read_input_tokens = usage.cache_read_input_tokens,
1818 );
1819 }
1820 })
1821 .ok();
1822 });
1823
1824 self.pending_completions.push(PendingCompletion {
1825 id: pending_completion_id,
1826 queue_state: QueueState::Sending,
1827 _task: task,
1828 });
1829 }
1830
1831 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1832 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1833 println!("No thread summary model");
1834 return;
1835 };
1836
1837 if !model.provider.is_authenticated(cx) {
1838 return;
1839 }
1840
1841 let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
1842
1843 let request = self.to_summarize_request(
1844 &model.model,
1845 CompletionIntent::ThreadSummarization,
1846 added_user_message.into(),
1847 cx,
1848 );
1849
1850 self.summary = ThreadSummary::Generating;
1851
1852 self.pending_summary = cx.spawn(async move |this, cx| {
1853 let result = async {
1854 let mut messages = model.model.stream_completion(request, &cx).await?;
1855
1856 let mut new_summary = String::new();
1857 while let Some(event) = messages.next().await {
1858 let Ok(event) = event else {
1859 continue;
1860 };
1861 let text = match event {
1862 LanguageModelCompletionEvent::Text(text) => text,
1863 LanguageModelCompletionEvent::StatusUpdate(
1864 CompletionRequestStatus::UsageUpdated { amount, limit },
1865 ) => {
1866 this.update(cx, |thread, cx| {
1867 thread.update_model_request_usage(amount as u32, limit, cx);
1868 })?;
1869 continue;
1870 }
1871 _ => continue,
1872 };
1873
1874 let mut lines = text.lines();
1875 new_summary.extend(lines.next());
1876
1877 // Stop if the LLM generated multiple lines.
1878 if lines.next().is_some() {
1879 break;
1880 }
1881 }
1882
1883 anyhow::Ok(new_summary)
1884 }
1885 .await;
1886
1887 this.update(cx, |this, cx| {
1888 match result {
1889 Ok(new_summary) => {
1890 if new_summary.is_empty() {
1891 this.summary = ThreadSummary::Error;
1892 } else {
1893 this.summary = ThreadSummary::Ready(new_summary.into());
1894 }
1895 }
1896 Err(err) => {
1897 this.summary = ThreadSummary::Error;
1898 log::error!("Failed to generate thread summary: {}", err);
1899 }
1900 }
1901 cx.emit(ThreadEvent::SummaryGenerated);
1902 })
1903 .log_err()?;
1904
1905 Some(())
1906 });
1907 }
1908
    /// Kicks off generation of a detailed (multi-line) thread summary unless an
    /// up-to-date one already exists or is being generated.
    ///
    /// "Up-to-date" means the existing summary (or in-flight generation) covers
    /// the current last message. On completion the result is published through
    /// the `detailed_summary_tx` watch channel and the thread is saved so the
    /// summary can be reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Publish the "generating" state before spawning, tagged with the
        // message it covers so staleness can be detected later.
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request can't even be started, reset to NotGenerated so a
            // later attempt isn't blocked by a stuck "Generating" state.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
1998
1999 pub async fn wait_for_detailed_summary_or_text(
2000 this: &Entity<Self>,
2001 cx: &mut AsyncApp,
2002 ) -> Option<SharedString> {
2003 let mut detailed_summary_rx = this
2004 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2005 .ok()?;
2006 loop {
2007 match detailed_summary_rx.recv().await? {
2008 DetailedSummaryState::Generating { .. } => {}
2009 DetailedSummaryState::NotGenerated => {
2010 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2011 }
2012 DetailedSummaryState::Generated { text, .. } => return Some(text),
2013 }
2014 }
2015 }
2016
2017 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2018 self.detailed_summary_rx
2019 .borrow()
2020 .text()
2021 .unwrap_or_else(|| self.text().into())
2022 }
2023
    /// Reports whether a detailed summary is currently being generated.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2030
    /// Dispatches every idle pending tool use: tools that require confirmation
    /// (and aren't globally auto-allowed) are queued for user approval, known
    /// tools are run immediately, and unknown (hallucinated) tool names are
    /// answered with an error result.
    ///
    /// Returns the list of pending tool uses that were dispatched.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // One shared request snapshot is built up front and handed to each tool.
        let request =
            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                // Confirmation is skipped when the user has opted into always
                // allowing tool actions.
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AgentSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool that doesn't exist or is disabled.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
2085
2086 pub fn handle_hallucinated_tool_use(
2087 &mut self,
2088 tool_use_id: LanguageModelToolUseId,
2089 hallucinated_tool_name: Arc<str>,
2090 window: Option<AnyWindowHandle>,
2091 cx: &mut Context<Thread>,
2092 ) {
2093 let available_tools = self.profile.enabled_tools(cx);
2094
2095 let tool_list = available_tools
2096 .iter()
2097 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2098 .collect::<Vec<_>>()
2099 .join("\n");
2100
2101 let error_message = format!(
2102 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2103 hallucinated_tool_name, tool_list
2104 );
2105
2106 let pending_tool_use = self.tool_use.insert_tool_output(
2107 tool_use_id.clone(),
2108 hallucinated_tool_name,
2109 Err(anyhow!("Missing tool call: {error_message}")),
2110 self.configured_model.as_ref(),
2111 );
2112
2113 cx.emit(ThreadEvent::MissingToolUse {
2114 tool_use_id: tool_use_id.clone(),
2115 ui_text: error_message.into(),
2116 });
2117
2118 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2119 }
2120
2121 pub fn receive_invalid_tool_json(
2122 &mut self,
2123 tool_use_id: LanguageModelToolUseId,
2124 tool_name: Arc<str>,
2125 invalid_json: Arc<str>,
2126 error: String,
2127 window: Option<AnyWindowHandle>,
2128 cx: &mut Context<Thread>,
2129 ) {
2130 log::error!("The model returned invalid input JSON: {invalid_json}");
2131
2132 let pending_tool_use = self.tool_use.insert_tool_output(
2133 tool_use_id.clone(),
2134 tool_name,
2135 Err(anyhow!("Error parsing input JSON: {error}")),
2136 self.configured_model.as_ref(),
2137 );
2138 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2139 pending_tool_use.ui_text.clone()
2140 } else {
2141 log::error!(
2142 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2143 );
2144 format!("Unknown tool {}", tool_use_id).into()
2145 };
2146
2147 cx.emit(ThreadEvent::InvalidToolInput {
2148 tool_use_id: tool_use_id.clone(),
2149 ui_text,
2150 invalid_input_json: invalid_json,
2151 });
2152
2153 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2154 }
2155
    /// Runs `tool` with the given input and marks the tool use as running.
    ///
    /// The spawned task delivers the tool's output back into this thread when
    /// it completes (see `spawn_tool_use`).
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task =
            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
2172
    /// Starts `tool` running and returns a task that records its output on
    /// this thread when it finishes.
    ///
    /// If the tool produces a UI card it is stored immediately (before the
    /// output resolves) so the card can render while the tool is still running.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // Kicks off the tool; `tool_result.output` is the future for its result.
        let tool_result = tool.run(
            input,
            request,
            self.project.clone(),
            self.action_log.clone(),
            model,
            window,
            cx,
        );

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // If the thread was dropped meanwhile, the output is discarded.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2219
2220 fn tool_finished(
2221 &mut self,
2222 tool_use_id: LanguageModelToolUseId,
2223 pending_tool_use: Option<PendingToolUse>,
2224 canceled: bool,
2225 window: Option<AnyWindowHandle>,
2226 cx: &mut Context<Self>,
2227 ) {
2228 if self.all_tools_finished() {
2229 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2230 if !canceled {
2231 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2232 }
2233 self.auto_capture_telemetry(cx);
2234 }
2235 }
2236
2237 cx.emit(ThreadEvent::ToolFinished {
2238 tool_use_id,
2239 pending_tool_use,
2240 });
2241 }
2242
    /// Cancels the last pending completion, if there are any pending.
    ///
    /// Also cancels every pending tool use, routing each through
    /// `tool_finished` with `canceled = true` so no follow-up request is sent.
    ///
    /// Returns whether a completion was canceled.
    pub fn cancel_last_completion(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) -> bool {
        // Popping drops the `PendingCompletion` (and its `_task`), which stops
        // the in-flight request.
        let mut canceled = self.pending_completions.pop().is_some();

        for pending_tool_use in self.tool_use.cancel_pending() {
            canceled = true;
            self.tool_finished(
                pending_tool_use.id.clone(),
                Some(pending_tool_use),
                true,
                window,
                cx,
            );
        }

        if canceled {
            cx.emit(ThreadEvent::CompletionCanceled);

            // When canceled, we always want to insert the checkpoint.
            // (We skip over finalize_pending_checkpoint, because it
            // would conclude we didn't have anything to insert here.)
            if let Some(checkpoint) = self.pending_checkpoint.take() {
                self.insert_checkpoint(checkpoint, cx);
            }
        } else {
            self.finalize_pending_checkpoint(cx);
        }

        canceled
    }
2279
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. The thread itself holds no
    /// editing state; it only emits [`ThreadEvent::CancelEditing`].
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2287
    /// Thread-level feedback, if the user has rated this thread as a whole.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2291
    /// Feedback recorded for the given message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2295
    /// Records `feedback` for a specific message and reports it to telemetry,
    /// together with a serialized copy of the thread and a project snapshot.
    ///
    /// No-ops when the same feedback is already recorded for that message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Gather everything that must be read on the main thread before
        // moving into the background task.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .profile
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // A JSON conversion failure shouldn't abort the report; fall back to null.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2352
    /// Records thread-level feedback and reports it to telemetry.
    ///
    /// If the thread contains an assistant message, the feedback is attached
    /// to the most recent one via [`Self::report_message_feedback`]; otherwise
    /// it is stored on the thread itself and reported without message details.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // Read everything needed on the main thread, then report in the
            // background.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Fall back to null rather than failing the report.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2398
    /// Create a snapshot of the current project state including git information and unsaved buffers.
    ///
    /// Snapshotting is best-effort: failures while reading app state are
    /// ignored rather than propagated.
    fn project_snapshot(
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Arc<ProjectSnapshot>> {
        let git_store = project.read(cx).git_store().clone();
        // Kick off one snapshot task per visible worktree up front, then join
        // them all in the spawned task.
        let worktree_snapshots: Vec<_> = project
            .read(cx)
            .visible_worktrees(cx)
            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
            .collect();

        cx.spawn(async move |_, cx| {
            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;

            // Collect the paths of all open buffers with unsaved changes.
            let mut unsaved_buffers = Vec::new();
            cx.update(|app_cx| {
                let buffer_store = project.read(app_cx).buffer_store();
                for buffer_handle in buffer_store.read(app_cx).buffers() {
                    let buffer = buffer_handle.read(app_cx);
                    if buffer.is_dirty() {
                        if let Some(file) = buffer.file() {
                            let path = file.path().to_string_lossy().to_string();
                            unsaved_buffers.push(path);
                        }
                    }
                }
            })
            .ok();

            Arc::new(ProjectSnapshot {
                worktree_snapshots,
                unsaved_buffer_paths: unsaved_buffers,
                timestamp: Utc::now(),
            })
        })
    }
2436
2437 fn worktree_snapshot(
2438 worktree: Entity<project::Worktree>,
2439 git_store: Entity<GitStore>,
2440 cx: &App,
2441 ) -> Task<WorktreeSnapshot> {
2442 cx.spawn(async move |cx| {
2443 // Get worktree path and snapshot
2444 let worktree_info = cx.update(|app_cx| {
2445 let worktree = worktree.read(app_cx);
2446 let path = worktree.abs_path().to_string_lossy().to_string();
2447 let snapshot = worktree.snapshot();
2448 (path, snapshot)
2449 });
2450
2451 let Ok((worktree_path, _snapshot)) = worktree_info else {
2452 return WorktreeSnapshot {
2453 worktree_path: String::new(),
2454 git_state: None,
2455 };
2456 };
2457
2458 let git_state = git_store
2459 .update(cx, |git_store, cx| {
2460 git_store
2461 .repositories()
2462 .values()
2463 .find(|repo| {
2464 repo.read(cx)
2465 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
2466 .is_some()
2467 })
2468 .cloned()
2469 })
2470 .ok()
2471 .flatten()
2472 .map(|repo| {
2473 repo.update(cx, |repo, _| {
2474 let current_branch =
2475 repo.branch.as_ref().map(|branch| branch.name().to_owned());
2476 repo.send_job(None, |state, _| async move {
2477 let RepositoryState::Local { backend, .. } = state else {
2478 return GitState {
2479 remote_url: None,
2480 head_sha: None,
2481 current_branch,
2482 diff: None,
2483 };
2484 };
2485
2486 let remote_url = backend.remote_url("origin");
2487 let head_sha = backend.head_sha().await;
2488 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
2489
2490 GitState {
2491 remote_url,
2492 head_sha,
2493 current_branch,
2494 diff,
2495 }
2496 })
2497 })
2498 });
2499
2500 let git_state = match git_state {
2501 Some(git_state) => match git_state.ok() {
2502 Some(git_state) => git_state.await.ok(),
2503 None => None,
2504 },
2505 None => None,
2506 };
2507
2508 WorktreeSnapshot {
2509 worktree_path,
2510 git_state,
2511 }
2512 })
2513 }
2514
2515 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2516 let mut markdown = Vec::new();
2517
2518 let summary = self.summary().or_default();
2519 writeln!(markdown, "# {summary}\n")?;
2520
2521 for message in self.messages() {
2522 writeln!(
2523 markdown,
2524 "## {role}\n",
2525 role = match message.role {
2526 Role::User => "User",
2527 Role::Assistant => "Agent",
2528 Role::System => "System",
2529 }
2530 )?;
2531
2532 if !message.loaded_context.text.is_empty() {
2533 writeln!(markdown, "{}", message.loaded_context.text)?;
2534 }
2535
2536 if !message.loaded_context.images.is_empty() {
2537 writeln!(
2538 markdown,
2539 "\n{} images attached as context.\n",
2540 message.loaded_context.images.len()
2541 )?;
2542 }
2543
2544 for segment in &message.segments {
2545 match segment {
2546 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2547 MessageSegment::Thinking { text, .. } => {
2548 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2549 }
2550 MessageSegment::RedactedThinking(_) => {}
2551 }
2552 }
2553
2554 for tool_use in self.tool_uses_for_message(message.id, cx) {
2555 writeln!(
2556 markdown,
2557 "**Use Tool: {} ({})**",
2558 tool_use.name, tool_use.id
2559 )?;
2560 writeln!(markdown, "```json")?;
2561 writeln!(
2562 markdown,
2563 "{}",
2564 serde_json::to_string_pretty(&tool_use.input)?
2565 )?;
2566 writeln!(markdown, "```")?;
2567 }
2568
2569 for tool_result in self.tool_results_for_message(message.id) {
2570 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2571 if tool_result.is_error {
2572 write!(markdown, " (Error)")?;
2573 }
2574
2575 writeln!(markdown, "**\n")?;
2576 match &tool_result.content {
2577 LanguageModelToolResultContent::Text(text) => {
2578 writeln!(markdown, "{text}")?;
2579 }
2580 LanguageModelToolResultContent::Image(image) => {
2581 writeln!(markdown, "", image.source)?;
2582 }
2583 }
2584
2585 if let Some(output) = tool_result.output.as_ref() {
2586 writeln!(
2587 markdown,
2588 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2589 serde_json::to_string_pretty(output)?
2590 )?;
2591 }
2592 }
2593 }
2594
2595 Ok(String::from_utf8_lossy(&markdown).to_string())
2596 }
2597
2598 pub fn keep_edits_in_range(
2599 &mut self,
2600 buffer: Entity<language::Buffer>,
2601 buffer_range: Range<language::Anchor>,
2602 cx: &mut Context<Self>,
2603 ) {
2604 self.action_log.update(cx, |action_log, cx| {
2605 action_log.keep_edits_in_range(buffer, buffer_range, cx)
2606 });
2607 }
2608
2609 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
2610 self.action_log
2611 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
2612 }
2613
2614 pub fn reject_edits_in_ranges(
2615 &mut self,
2616 buffer: Entity<language::Buffer>,
2617 buffer_ranges: Vec<Range<language::Anchor>>,
2618 cx: &mut Context<Self>,
2619 ) -> Task<Result<()>> {
2620 self.action_log.update(cx, |action_log, cx| {
2621 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
2622 })
2623 }
2624
    /// The log of actions (edits) the agent has performed in this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2628
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2632
2633 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2634 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2635 return;
2636 }
2637
2638 let now = Instant::now();
2639 if let Some(last) = self.last_auto_capture_at {
2640 if now.duration_since(last).as_secs() < 10 {
2641 return;
2642 }
2643 }
2644
2645 self.last_auto_capture_at = Some(now);
2646
2647 let thread_id = self.id().clone();
2648 let github_login = self
2649 .project
2650 .read(cx)
2651 .user_store()
2652 .read(cx)
2653 .current_user()
2654 .map(|user| user.github_login.clone());
2655 let client = self.project.read(cx).client();
2656 let serialize_task = self.serialize(cx);
2657
2658 cx.background_executor()
2659 .spawn(async move {
2660 if let Ok(serialized_thread) = serialize_task.await {
2661 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2662 telemetry::event!(
2663 "Agent Thread Auto-Captured",
2664 thread_id = thread_id.to_string(),
2665 thread_data = thread_data,
2666 auto_capture_reason = "tracked_user",
2667 github_login = github_login
2668 );
2669
2670 client.telemetry().flush_events().await;
2671 }
2672 }
2673 })
2674 .detach();
2675 }
2676
    /// Total token usage accumulated across every request in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2680
2681 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2682 let Some(model) = self.configured_model.as_ref() else {
2683 return TotalTokenUsage::default();
2684 };
2685
2686 let max = model.model.max_token_count();
2687
2688 let index = self
2689 .messages
2690 .iter()
2691 .position(|msg| msg.id == message_id)
2692 .unwrap_or(0);
2693
2694 if index == 0 {
2695 return TotalTokenUsage { total: 0, max };
2696 }
2697
2698 let token_usage = &self
2699 .request_token_usage
2700 .get(index - 1)
2701 .cloned()
2702 .unwrap_or_default();
2703
2704 TotalTokenUsage {
2705 total: token_usage.total_tokens(),
2706 max,
2707 }
2708 }
2709
2710 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2711 let model = self.configured_model.as_ref()?;
2712
2713 let max = model.model.max_token_count();
2714
2715 if let Some(exceeded_error) = &self.exceeded_window_error {
2716 if model.model.id() == exceeded_error.model_id {
2717 return Some(TotalTokenUsage {
2718 total: exceeded_error.token_count,
2719 max,
2720 });
2721 }
2722 }
2723
2724 let total = self
2725 .token_usage_at_last_message()
2726 .unwrap_or_default()
2727 .total_tokens();
2728
2729 Some(TotalTokenUsage { total, max })
2730 }
2731
    /// Token usage recorded for the most recent message, falling back to the
    /// last recorded entry when `request_token_usage` is shorter than the
    /// message list.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2738
    /// Overwrites the token usage recorded for the last message, growing
    /// `request_token_usage` to match the message count if necessary.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Pad any missing entries with the most recent known usage so the
        // vector's indices keep lining up with `self.messages`.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2748
    /// Forwards the latest model-request usage (`amount` requests against
    /// `limit`) to the project's user store.
    fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
        self.project.update(cx, |project, cx| {
            project.user_store().update(cx, |user_store, cx| {
                user_store.update_model_request_usage(
                    ModelRequestUsage(RequestUsage {
                        // RequestUsage stores a signed amount; `amount` comes
                        // from the server as u32.
                        amount: amount as i32,
                        limit,
                    }),
                    cx,
                )
            })
        });
    }
2762
2763 pub fn deny_tool_use(
2764 &mut self,
2765 tool_use_id: LanguageModelToolUseId,
2766 tool_name: Arc<str>,
2767 window: Option<AnyWindowHandle>,
2768 cx: &mut Context<Self>,
2769 ) {
2770 let err = Err(anyhow::anyhow!(
2771 "Permission to run tool action denied by user"
2772 ));
2773
2774 self.tool_use.insert_tool_output(
2775 tool_use_id.clone(),
2776 tool_name,
2777 err,
2778 self.configured_model.as_ref(),
2779 );
2780 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2781 }
2782}
2783
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The user's account requires payment before further requests succeed.
    #[error("Payment required")]
    PaymentRequired,
    /// The user hit their plan's model-request limit.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Any other error, displayed with a short header and a detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2796
/// Events emitted by a `Thread` as completions stream, tools run, and
/// messages change. Listeners subscribe via gpui's `EventEmitter`.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Surface an error to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// A chunk of assistant text was streamed into the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of assistant "thinking" text was streamed into the given message.
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that could not be used; carries the raw
    /// JSON text it sent.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion finished, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A single tool use completed (see `Thread::tool_finished`).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
}
2841
// Lets `Thread` broadcast `ThreadEvent`s to subscribers via `cx.emit`.
impl EventEmitter<ThreadEvent> for Thread {}
2843
/// A model completion that is currently in flight.
struct PendingCompletion {
    // Identifies this completion; `post_inc`-style ids distinguish overlapping runs.
    id: usize,
    queue_state: QueueState,
    // Keeps the completion running; dropping it (e.g. when popped in
    // `cancel_last_completion`) stops the request.
    _task: Task<()>,
}
2849
2850#[cfg(test)]
2851mod tests {
2852 use super::*;
2853 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2854 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
2855 use assistant_tool::ToolRegistry;
2856 use editor::EditorSettings;
2857 use gpui::TestAppContext;
2858 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2859 use project::{FakeFs, Project};
2860 use prompt_store::PromptBuilder;
2861 use serde_json::json;
2862 use settings::{Settings, SettingsStore};
2863 use std::sync::Arc;
2864 use theme::ThemeSettings;
2865 use util::path;
2866 use workspace::Workspace;
2867
    // Verifies that file context attached to a user message is rendered into
    // the message's `<context>` block and prefixed onto the message text in
    // the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Message 0 is the system prompt; message 1 is the user message with
        // the context block prepended.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2944
    // Verifies that each new user message only carries context that hasn't
    // already been attached to an earlier message, and that
    // `new_context_for_thread` respects an "editing from message X" cutoff.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Editing from message 2's position: contexts attached at or after
        // message 2 (file2, file3) plus the newly added file4 count as "new".
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3100
    // Verifies that messages inserted without any attached context have an
    // empty context block and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3178
    // Verifies that a fresh thread starts with the default agent profile.
    #[gpui::test]
    async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Check that we are starting with the default profile
        let profile = cx.read(|cx| thread.read(cx).profile.clone());
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());
        assert_eq!(
            profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3200
    // Verifies that the agent profile round-trips through thread
    // serialization and deserialization.
    #[gpui::test]
    async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, thread_store, thread, _context_store, _model) =
            setup_test_environment(cx, project.clone()).await;

        // Profile gets serialized with default values
        let serialized = thread
            .update(cx, |thread, cx| thread.serialize(cx))
            .await
            .unwrap();

        assert_eq!(serialized.profile, Some(AgentProfileId::default()));

        // Deserializing should reconstruct the same profile.
        let deserialized = cx.update(|cx| {
            thread.update(cx, |thread, cx| {
                Thread::deserialize(
                    thread.id.clone(),
                    serialized,
                    thread.project.clone(),
                    thread.tools.clone(),
                    thread.prompt_builder.clone(),
                    thread.project_context.clone(),
                    None,
                    cx,
                )
            })
        });
        let tool_set = cx.read(|cx| thread_store.read(cx).tools());

        assert_eq!(
            deserialized.profile,
            AgentProfile::new(AgentProfileId::default(), tool_set)
        );
    }
3243
    // Verifies how `model_parameters` in the agent settings select a
    // temperature: matching on provider+model, model only, or provider only
    // applies the override; a provider mismatch does not.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        // Provider mismatch: the override must not apply.
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }
3337
    // Exercises the full summary lifecycle: Pending -> Generating -> Ready,
    // including that manual `set_summary` is rejected while Pending/Generating
    // and accepted (with empty input falling back to the default) once Ready.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending; `or_default()` substitutes the
        // placeholder title while no summary exists yet.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        // Complete the assistant turn so the thread has >= 2 messages.
        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary text in two chunks; the chunks should be
        // concatenated into a single title once the stream ends.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        // Empty input keeps the Ready state but displays the default title.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3422
    // After summarization fails (driven by `test_summarize_error`), a manual
    // `set_summary` should succeed and move the summary to Ready.
    #[gpui::test]
    async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the ThreadSummary::Error state.
        test_summarize_error(&model, &thread, cx);

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });
    }
3444
    // After summarization fails, further messages must NOT retrigger an
    // automatic summary request, but an explicit `summarize()` call retries
    // and can succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the ThreadSummary::Error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary for the retried request.
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3495
    /// Shared helper: sends a message, completes the assistant turn, then ends
    /// the summary completion stream without emitting any text, which leaves
    /// the thread's summary in the `ThreadSummary::Error` state.
    ///
    /// Asserts the intermediate `Generating` state and the final `Error` state
    /// (with the default placeholder title) along the way.
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // The assistant turn finished, so summarization should now be running.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending (with no streamed text at all)
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3530
    /// Completes the fake model's most recent completion request with a canned
    /// "Assistant response" chunk, parking the executor before and after so
    /// the request is dispatched and the stream events are fully processed.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3537
    /// Installs a test `SettingsStore` and registers every settings type and
    /// subsystem these tests read (language, project, agent, prompt store,
    /// thread store, workspace, language-model, theme, editor settings, and
    /// the global tool registry).
    ///
    /// NOTE(review): the store is installed first since the subsequent
    /// registrations presumably write into it — keep that ordering.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3554
3555 // Helper to create a test project with test files
3556 async fn create_test_project(
3557 cx: &mut TestAppContext,
3558 files: serde_json::Value,
3559 ) -> Entity<Project> {
3560 let fs = FakeFs::new(cx.executor());
3561 fs.insert_tree(path!("/test"), files).await;
3562 Project::test(fs, [path!("/test").as_ref()], cx).await
3563 }
3564
    /// Builds the common test fixture: a test workspace for `project`, a
    /// loaded `ThreadStore`, one freshly created thread, a `ContextStore`,
    /// and a fake language model that is registered as both the default model
    /// and the thread-summary model, so every completion in a test goes
    /// through the returned fake.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        // Load a thread store with an empty tool working set and the default
        // prompt builder.
        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        // One fake model serves all roles below.
        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }
3619
3620 async fn add_file_to_context(
3621 project: &Entity<Project>,
3622 context_store: &Entity<ContextStore>,
3623 path: &str,
3624 cx: &mut TestAppContext,
3625 ) -> Result<Entity<language::Buffer>> {
3626 let buffer_path = project
3627 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3628 .unwrap();
3629
3630 let buffer = project
3631 .update(cx, |project, cx| {
3632 project.open_buffer(buffer_path.clone(), cx)
3633 })
3634 .await
3635 .unwrap();
3636
3637 context_store.update(cx, |context_store, cx| {
3638 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3639 });
3640
3641 Ok(buffer)
3642 }
3643}