1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use agent_settings::{AgentSettings, CompletionMode};
8use anyhow::{Result, anyhow};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage, WrappedTextContent,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::CompletionRequestStatus;
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
51#[derive(
52 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
53)]
54pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
74/// The ID of the user prompt that initiated a request.
75///
76/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
77#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
78pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
92#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
93pub struct MessageId(pub(crate) usize);
94
95impl MessageId {
96 fn post_inc(&mut self) -> Self {
97 Self(post_inc(&mut self.0))
98 }
99}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    // Byte range of the crease within the message text.
    pub range: Range<usize>,
    // Display metadata (icon path and label) used to render the crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    // Ordered content segments (plain text, thinking, redacted thinking).
    pub segments: Vec<MessageSegment>,
    // Context attached to this message; deserialized messages keep only the text.
    pub loaded_context: LoadedContext,
    // Creases used to re-render collapsed context mentions in editors.
    pub creases: Vec<MessageCrease>,
    // Synthetic messages (see `insert_invisible_continue_message`) are marked hidden.
    pub is_hidden: bool,
}
120
121impl Message {
122 /// Returns whether the message contains any meaningful text that should be displayed
123 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
124 pub fn should_display_content(&self) -> bool {
125 self.segments.iter().all(|segment| segment.should_display())
126 }
127
128 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
129 if let Some(MessageSegment::Thinking {
130 text: segment,
131 signature: current_signature,
132 }) = self.segments.last_mut()
133 {
134 if let Some(signature) = signature {
135 *current_signature = Some(signature);
136 }
137 segment.push_str(text);
138 } else {
139 self.segments.push(MessageSegment::Thinking {
140 text: text.to_string(),
141 signature,
142 });
143 }
144 }
145
146 pub fn push_text(&mut self, text: &str) {
147 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
148 segment.push_str(text);
149 } else {
150 self.segments.push(MessageSegment::Text(text.to_string()));
151 }
152 }
153
154 pub fn to_string(&self) -> String {
155 let mut result = String::new();
156
157 if !self.loaded_context.text.is_empty() {
158 result.push_str(&self.loaded_context.text);
159 }
160
161 for segment in &self.segments {
162 match segment {
163 MessageSegment::Text(text) => result.push_str(text),
164 MessageSegment::Thinking { text, .. } => {
165 result.push_str("<think>\n");
166 result.push_str(text);
167 result.push_str("\n</think>");
168 }
169 MessageSegment::RedactedThinking(_) => {}
170 }
171 }
172
173 result
174 }
175}
176
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain response text.
    Text(String),
    /// Chain-of-thought text, optionally accompanied by a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted thinking, kept only as opaque bytes.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Returns whether this segment has any content worth displaying:
    /// non-empty text or thinking. Redacted thinking is never displayed.
    ///
    /// Fix: this previously returned `text.is_empty()`, inverting the check
    /// and making `Message::should_display_content` report the opposite of
    /// its documented meaning.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
196
/// A point-in-time capture of the project's worktrees, taken when a thread is
/// created (see `Thread::new`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // Paths of buffers that had unsaved changes at capture time.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree, including its git metadata when available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // NOTE(review): presumably `None` for non-git worktrees — confirm at the capture site.
    pub git_state: Option<GitState>,
}

/// Git metadata captured for a worktree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    // Uncommitted diff text, if one was collected.
    pub diff: Option<String>,
}
217
/// Associates a message with a git checkpoint so edits made after that message
/// can be rolled back (see `Thread::restore_checkpoint`).
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread or message (thumbs up / thumbs down).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}

/// Progress of the most recent checkpoint restore.
pub enum LastRestoreCheckpoint {
    /// The restore is still running.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed with the contained error message.
    Error {
        message_id: MessageId,
        error: String,
    },
}
239
240impl LastRestoreCheckpoint {
241 pub fn message_id(&self) -> MessageId {
242 match self {
243 LastRestoreCheckpoint::Pending { message_id } => *message_id,
244 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
245 }
246 }
247}
248
249#[derive(Clone, Debug, Default, Serialize, Deserialize)]
250pub enum DetailedSummaryState {
251 #[default]
252 NotGenerated,
253 Generating {
254 message_id: MessageId,
255 },
256 Generated {
257 text: SharedString,
258 message_id: MessageId,
259 },
260}
261
262impl DetailedSummaryState {
263 fn text(&self) -> Option<SharedString> {
264 if let Self::Generated { text, .. } = self {
265 Some(text.clone())
266 } else {
267 None
268 }
269 }
270}
271
/// Running token count for a thread, relative to the model's context window.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context-window size; `0` when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage as normal, near the limit, or over it.
    ///
    /// In debug builds the warning threshold can be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable. A missing or
    /// malformed value falls back to the default instead of panicking
    /// (the previous `.parse().unwrap()` crashed on malformed input).
    pub fn ratio(&self) -> TokenUsageRatio {
        const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more counted against the same maximum.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

/// How close token usage is to the context-window limit.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
316
/// Where a pending completion currently sits before streaming begins.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    // Request sent; no queue status received yet.
    Sending,
    // Queued behind `position` other requests.
    Queued { position: usize },
    // Streaming has started.
    Started,
}
323
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    // Last time this thread was modified (see `touch_updated_at`).
    updated_at: DateTime<Utc>,
    // Short auto-generated title for the thread.
    summary: ThreadSummary,
    // Background task generating `summary`, if one is running.
    pending_summary: Task<Option<()>>,
    // Background task generating the detailed summary, if one is running.
    detailed_summary_task: Task<Option<()>>,
    // Watch channel broadcasting `DetailedSummaryState` changes.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    // Kept in ascending `MessageId` order; `Self::message` relies on this for binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    // Id of the most recent user-initiated prompt; see [`PromptId`].
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    // Git checkpoints keyed by the message that created them.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    // State of the most recent checkpoint restore, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint captured with the in-flight user message; finalized once generation ends.
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Snapshot of the project taken when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    // Per-request token usage history.
    request_token_usage: Vec<TokenUsage>,
    // Token usage accumulated across all requests in this thread.
    cumulative_token_usage: TokenUsage,
    // Set when the last request exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    // Whole-thread feedback from the user.
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
    // When the most recent completion chunk arrived; used by `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    // Observer invoked with each request and its streamed events (set via `set_request_callback`).
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    // Agentic turns still allowed; `u32::MAX` means effectively unlimited.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
364
365#[derive(Clone, Debug, PartialEq, Eq)]
366pub enum ThreadSummary {
367 Pending,
368 Generating,
369 Ready(SharedString),
370 Error,
371}
372
373impl ThreadSummary {
374 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
375
376 pub fn or_default(&self) -> SharedString {
377 self.unwrap_or(Self::DEFAULT)
378 }
379
380 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
381 self.ready().unwrap_or_else(|| message.into())
382 }
383
384 pub fn ready(&self) -> Option<SharedString> {
385 match self {
386 ThreadSummary::Ready(summary) => Some(summary.clone()),
387 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
388 }
389 }
390}
391
/// Records that the most recent request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
399
400impl Thread {
    /// Creates an empty thread backed by `project`, capturing an initial
    /// project snapshot and the currently configured default model.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        let configured_model = LanguageModelRegistry::read_global(cx).default_model();

        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: ThreadSummary::Pending,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            last_prompt_id: PromptId::new(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // Kick the snapshot off immediately so it reflects the project
                // state at thread-creation time.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            // Unlimited turns by default.
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
454
    /// Rebuilds a thread from its serialized form.
    ///
    /// Context handles and checkpoints are not persisted: restored creases get
    /// `context: None` and `loaded_context` keeps only its text.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the highest restored message id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model recorded in the serialized thread; fall back to the
        // global default when it is missing or no longer available.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the context text survives serialization.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
575
    /// Stores a callback that will receive each completion request together
    /// with the events streamed back for it.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }
583
    /// Unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// True when the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// When the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Marks the thread as modified now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt id; subsequent requests belong to a new user prompt.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    /// Shared project-context handle used to build the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
607
608 pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
609 if self.configured_model.is_none() {
610 self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
611 }
612 self.configured_model.clone()
613 }
614
    /// The model currently selected for this thread, if any.
    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    /// Selects (or clears) the model used by this thread.
    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }

    /// Current state of the auto-generated thread title.
    pub fn summary(&self) -> &ThreadSummary {
        &self.summary
    }
627
    /// Replaces the thread summary, substituting the default title for an
    /// empty string and emitting `SummaryChanged` only when the text changes.
    ///
    /// No-op while a summary is still pending or being generated.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        let current_summary = match &self.summary {
            ThreadSummary::Pending | ThreadSummary::Generating => return,
            ThreadSummary::Ready(summary) => summary,
            // After an error, compare against the default title.
            ThreadSummary::Error => &ThreadSummary::DEFAULT,
        };

        let mut new_summary = new_summary.into();

        if new_summary.is_empty() {
            new_summary = ThreadSummary::DEFAULT;
        }

        if current_summary != &new_summary {
            self.summary = ThreadSummary::Ready(new_summary);
            cx.emit(ThreadEvent::SummaryChanged);
        }
    }
646
    /// The completion mode used for requests from this thread.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    /// Overrides the completion mode for this thread.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    /// Looks up a message by id.
    ///
    /// Relies on `self.messages` being sorted by ascending id: ids are
    /// assigned monotonically and messages are only appended or truncated.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// All messages, oldest first.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
667
    /// True while a completion is in flight or tools are still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    /// When `is_generating()` is false, this method returns `None`.
    pub fn is_generation_stale(&self) -> Option<bool> {
        // Milliseconds without a chunk before the stream counts as stale.
        const STALE_THRESHOLD: u128 = 250;

        // NOTE(review): this only inspects `last_received_chunk_at`; the doc
        // comment's "returns None when not generating" holds only if that
        // field is cleared when generation ends — confirm at the call sites
        // that drive completions.
        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records that a streamed chunk just arrived, resetting the staleness timer.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue state of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
690
    /// The working set of tools available to this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// Looks up an in-flight tool use by id.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Tool uses that are waiting on user confirmation before running.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    /// True if any tool use is still pending.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for a message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
716
    /// Restores the project to `checkpoint`'s git state and, on success,
    /// truncates the thread back to the checkpoint's message.
    ///
    /// Progress and failure are tracked in `last_restore_checkpoint` and
    /// surfaced via `CheckpointChanged` events.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Only drop messages once the git restore succeeded.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
752
753 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
754 let pending_checkpoint = if self.is_generating() {
755 return;
756 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
757 checkpoint
758 } else {
759 return;
760 };
761
762 self.finalize_checkpoint(pending_checkpoint, cx);
763 }
764
    /// Compares `pending_checkpoint` against the repository's current state
    /// and keeps it only when the two differ (i.e. the turn actually changed
    /// something). If a fresh checkpoint can't be taken, the pending one is
    /// kept conservatively.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // Treat a failed comparison as "not equal" so the checkpoint is kept.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Couldn't take a fresh checkpoint; keep the pending one.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
799
    /// Records a checkpoint for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// State of the most recent checkpoint restore, if any.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
810
811 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
812 let Some(message_ix) = self
813 .messages
814 .iter()
815 .rposition(|message| message.id == message_id)
816 else {
817 return;
818 };
819 for deleted_message in self.messages.drain(message_ix..) {
820 self.checkpoints_by_message.remove(&deleted_message.id);
821 }
822 cx.notify();
823 }
824
825 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
826 self.messages
827 .iter()
828 .find(|message| message.id == id)
829 .into_iter()
830 .flat_map(|message| message.loaded_context.contexts.iter())
831 }
832
833 pub fn is_turn_end(&self, ix: usize) -> bool {
834 if self.messages.is_empty() {
835 return false;
836 }
837
838 if !self.is_generating() && ix == self.messages.len() - 1 {
839 return true;
840 }
841
842 let Some(message) = self.messages.get(ix) else {
843 return false;
844 };
845
846 if message.role != Role::Assistant {
847 return false;
848 }
849
850 self.messages
851 .get(ix + 1)
852 .and_then(|message| {
853 self.message(message.id)
854 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
855 })
856 .unwrap_or(false)
857 }
858
    /// Usage reported by the most recent completion, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// True when the provider reported that the tool-use limit was reached.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
876
    /// Tool uses issued by the given assistant message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Tool results produced for the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    /// The result recorded for a particular tool use, if it has finished.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// Text output of a finished tool use; `None` for image results.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text)
            | LanguageModelToolResultContent::WrappedText(WrappedTextContent { text, .. }) => {
                Some(text)
            }
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// UI card associated with a tool's result, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
908
909 /// Return tools that are both enabled and supported by the model
910 pub fn available_tools(
911 &self,
912 cx: &App,
913 model: Arc<dyn LanguageModel>,
914 ) -> Vec<LanguageModelRequestTool> {
915 if model.supports_tools() {
916 self.tools()
917 .read(cx)
918 .enabled_tools(cx)
919 .into_iter()
920 .filter_map(|tool| {
921 // Skip tools that cannot be supported
922 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
923 Some(LanguageModelRequestTool {
924 name: tool.name(),
925 description: tool.description(),
926 input_schema,
927 })
928 })
929 .collect()
930 } else {
931 Vec::default()
932 }
933 }
934
    /// Appends a visible user message, marking any context-referenced buffers
    /// as read in the action log and, when `git_checkpoint` is provided,
    /// stashing a checkpoint to finalize once the turn completes.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
971
972 pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
973 let id = self.insert_message(
974 Role::User,
975 vec![MessageSegment::Text("Continue where you left off".into())],
976 LoadedContext::default(),
977 vec![],
978 true,
979 cx,
980 );
981 self.pending_checkpoint = None;
982
983 id
984 }
985
    /// Appends a visible assistant message with the given segments.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1000
1001 pub fn insert_message(
1002 &mut self,
1003 role: Role,
1004 segments: Vec<MessageSegment>,
1005 loaded_context: LoadedContext,
1006 creases: Vec<MessageCrease>,
1007 is_hidden: bool,
1008 cx: &mut Context<Self>,
1009 ) -> MessageId {
1010 let id = self.next_message_id.post_inc();
1011 self.messages.push(Message {
1012 id,
1013 role,
1014 segments,
1015 loaded_context,
1016 creases,
1017 is_hidden,
1018 });
1019 self.touch_updated_at();
1020 cx.emit(ThreadEvent::MessageAdded(id));
1021 id
1022 }
1023
1024 pub fn edit_message(
1025 &mut self,
1026 id: MessageId,
1027 new_role: Role,
1028 new_segments: Vec<MessageSegment>,
1029 loaded_context: Option<LoadedContext>,
1030 checkpoint: Option<GitStoreCheckpoint>,
1031 cx: &mut Context<Self>,
1032 ) -> bool {
1033 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1034 return false;
1035 };
1036 message.role = new_role;
1037 message.segments = new_segments;
1038 if let Some(context) = loaded_context {
1039 message.loaded_context = context;
1040 }
1041 if let Some(git_checkpoint) = checkpoint {
1042 self.checkpoints_by_message.insert(
1043 id,
1044 ThreadCheckpoint {
1045 message_id: id,
1046 git_checkpoint,
1047 },
1048 );
1049 }
1050 self.touch_updated_at();
1051 cx.emit(ThreadEvent::MessageEdited(id));
1052 true
1053 }
1054
1055 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1056 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1057 return false;
1058 };
1059 self.messages.remove(index);
1060 self.touch_updated_at();
1061 cx.emit(ThreadEvent::MessageDeleted(id));
1062 true
1063 }
1064
1065 /// Returns the representation of this [`Thread`] in a textual form.
1066 ///
1067 /// This is the representation we use when attaching a thread as context to another thread.
1068 pub fn text(&self) -> String {
1069 let mut text = String::new();
1070
1071 for message in &self.messages {
1072 text.push_str(match message.role {
1073 language_model::Role::User => "User:",
1074 language_model::Role::Assistant => "Agent:",
1075 language_model::Role::System => "System:",
1076 });
1077 text.push('\n');
1078
1079 for segment in &message.segments {
1080 match segment {
1081 MessageSegment::Text(content) => text.push_str(content),
1082 MessageSegment::Thinking { text: content, .. } => {
1083 text.push_str(&format!("<think>{}</think>", content))
1084 }
1085 MessageSegment::RedactedThinking(_) => {}
1086 }
1087 }
1088 text.push('\n');
1089 }
1090
1091 text
1092 }
1093
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot captured at thread creation before reading state.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        // Tool uses/results are stored per originating message.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the context text is persisted (handles are not serializable).
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
            })
        })
    }
1178
    /// Returns how many more agentic turns this thread may take before it stops
    /// sending requests to the model.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1182
    /// Sets the number of agentic turns this thread may still take; `send_to_model`
    /// decrements this and becomes a no-op when it reaches zero.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1186
1187 pub fn send_to_model(
1188 &mut self,
1189 model: Arc<dyn LanguageModel>,
1190 window: Option<AnyWindowHandle>,
1191 cx: &mut Context<Self>,
1192 ) {
1193 if self.remaining_turns == 0 {
1194 return;
1195 }
1196
1197 self.remaining_turns -= 1;
1198
1199 let request = self.to_completion_request(model.clone(), cx);
1200
1201 self.stream_completion(request, model, window, cx);
1202 }
1203
1204 pub fn used_tools_since_last_user_message(&self) -> bool {
1205 for message in self.messages.iter().rev() {
1206 if self.tool_use.message_has_tool_results(message.id) {
1207 return true;
1208 } else if message.role == Role::User {
1209 return false;
1210 }
1211 }
1212
1213 false
1214 }
1215
    /// Builds the `LanguageModelRequest` for the next completion: the system
    /// prompt, the conversation so far (including tool uses and their results),
    /// a stale-file notice, the available tools, and the completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // System prompt: generated from the project context plus the tool list.
        // Failure to generate it is surfaced to the user but does not abort the
        // request — we proceed without a system message.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Tracks the last request-message index eligible for prompt caching;
        // only messages with no still-pending tool uses qualify.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) precedes the message text.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses are paired with a synthetic User message
            // carrying their results, as the API expects. Pending tool uses are
            // skipped and disqualify this message from caching.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // Mark only the final cacheable message as the cache breakpoint.
        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Burn/max mode only applies when the model supports it.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1373
1374 fn to_summarize_request(
1375 &self,
1376 model: &Arc<dyn LanguageModel>,
1377 added_user_message: String,
1378 cx: &App,
1379 ) -> LanguageModelRequest {
1380 let mut request = LanguageModelRequest {
1381 thread_id: None,
1382 prompt_id: None,
1383 mode: None,
1384 messages: vec![],
1385 tools: Vec::new(),
1386 tool_choice: None,
1387 stop: Vec::new(),
1388 temperature: AgentSettings::temperature_for_model(model, cx),
1389 };
1390
1391 for message in &self.messages {
1392 let mut request_message = LanguageModelRequestMessage {
1393 role: message.role,
1394 content: Vec::new(),
1395 cache: false,
1396 };
1397
1398 for segment in &message.segments {
1399 match segment {
1400 MessageSegment::Text(text) => request_message
1401 .content
1402 .push(MessageContent::Text(text.clone())),
1403 MessageSegment::Thinking { .. } => {}
1404 MessageSegment::RedactedThinking(_) => {}
1405 }
1406 }
1407
1408 if request_message.content.is_empty() {
1409 continue;
1410 }
1411
1412 request.messages.push(request_message);
1413 }
1414
1415 request.messages.push(LanguageModelRequestMessage {
1416 role: Role::User,
1417 content: vec![MessageContent::Text(added_user_message)],
1418 cache: false,
1419 });
1420
1421 request
1422 }
1423
1424 fn attached_tracked_files_state(
1425 &self,
1426 messages: &mut Vec<LanguageModelRequestMessage>,
1427 cx: &App,
1428 ) {
1429 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1430
1431 let mut stale_message = String::new();
1432
1433 let action_log = self.action_log.read(cx);
1434
1435 for stale_file in action_log.stale_buffers(cx) {
1436 let Some(file) = stale_file.read(cx).file() else {
1437 continue;
1438 };
1439
1440 if stale_message.is_empty() {
1441 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1442 }
1443
1444 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1445 }
1446
1447 let mut content = Vec::with_capacity(2);
1448
1449 if !stale_message.is_empty() {
1450 content.push(stale_message.into());
1451 }
1452
1453 if !content.is_empty() {
1454 let context_message = LanguageModelRequestMessage {
1455 role: Role::User,
1456 content,
1457 cache: false,
1458 };
1459
1460 messages.push(context_message);
1461 }
1462 }
1463
    /// Streams a completion from `model`, applying events to the thread as they
    /// arrive, and registers the work as a `PendingCompletion`.
    ///
    /// The spawned task:
    /// - appends streamed text/thinking chunks to the trailing assistant
    ///   message (creating one when needed),
    /// - records tool uses, token usage, and queue-status updates,
    /// - on stop, dispatches on the stop reason (run pending tools, clear the
    ///   agent location, or roll back a refused turn),
    /// - surfaces errors to the UI and emits a completion telemetry event.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
    ) {
        self.tool_use_limit_reached = false;

        let pending_completion_id = post_inc(&mut self.completion_count);
        // Only capture request/response data when a debug callback is installed.
        let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
        } else {
            None
        };
        let prompt_id = self.last_prompt_id.clone();
        let tool_use_metadata = ToolUseMetadata {
            model: model.clone(),
            thread_id: self.id.clone(),
            prompt_id: prompt_id.clone(),
        };

        self.last_received_chunk_at = Some(Instant::now());

        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion(request, &cx);
            // Snapshot the cumulative usage so the telemetry event at the end
            // can report this request's delta.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream_completion_future.await?;

                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                thread
                    .update(cx, |_thread, cx| {
                        cx.emit(ThreadEvent::NewRequest);
                    })
                    .ok();

                // The assistant message this response is streaming into, once created.
                let mut request_assistant_message_id = None;

                while let Some(event) = events.next().await {
                    // Mirror every raw event into the debug callback buffer.
                    if let Some((_, response_events)) = request_callback_parameters.as_mut() {
                        response_events
                            .push(event.as_ref().map_err(|error| error.to_string()).cloned());
                    }

                    thread.update(cx, |thread, cx| {
                        let event = match event {
                            Ok(event) => event,
                            Err(LanguageModelCompletionError::BadInputJson {
                                id,
                                tool_name,
                                raw_input: invalid_input_json,
                                json_parse_error,
                            }) => {
                                // Malformed tool input is recoverable: record an
                                // error tool result and keep streaming.
                                thread.receive_invalid_tool_json(
                                    id,
                                    tool_name,
                                    invalid_input_json,
                                    json_parse_error,
                                    window,
                                    cx,
                                );
                                return Ok(());
                            }
                            Err(LanguageModelCompletionError::Other(error)) => {
                                return Err(error);
                            }
                        };

                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                request_assistant_message_id =
                                    Some(thread.insert_assistant_message(
                                        vec![MessageSegment::Text(String::new())],
                                        cx,
                                    ));
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Usage updates appear cumulative per request
                                // here: add only the delta since the last update.
                                thread.update_token_usage_at_last_message(token_usage);
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                thread.received_chunk();

                                cx.emit(ThreadEvent::ReceivedTextChunk);
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Text(chunk.to_string())],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking {
                                text: chunk,
                                signature,
                            } => {
                                thread.received_chunk();

                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant
                                        && !thread.tool_use.has_tool_results(last_message.id)
                                    {
                                        last_message.push_thinking(&chunk, signature);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        request_assistant_message_id =
                                            Some(thread.insert_assistant_message(
                                                vec![MessageSegment::Thinking {
                                                    text: chunk.to_string(),
                                                    signature,
                                                }],
                                                cx,
                                            ));
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses attach to the current assistant
                                // message; create one if none exists yet.
                                let last_assistant_message_id = request_assistant_message_id
                                    .unwrap_or_else(|| {
                                        let new_assistant_message_id =
                                            thread.insert_assistant_message(vec![], cx);
                                        request_assistant_message_id =
                                            Some(new_assistant_message_id);
                                        new_assistant_message_id
                                    });

                                let tool_use_id = tool_use.id.clone();
                                // Incomplete input means this is a streaming
                                // (partial) tool-use update.
                                let streamed_input = if tool_use.is_input_complete {
                                    None
                                } else {
                                    Some((&tool_use.input).clone())
                                };

                                let ui_text = thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    tool_use_metadata.clone(),
                                    cx,
                                );

                                if let Some(input) = streamed_input {
                                    cx.emit(ThreadEvent::StreamedToolUse {
                                        tool_use_id,
                                        ui_text,
                                        input,
                                    });
                                }
                            }
                            LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                if let Some(completion) = thread
                                    .pending_completions
                                    .iter_mut()
                                    .find(|completion| completion.id == pending_completion_id)
                                {
                                    match status_update {
                                        CompletionRequestStatus::Queued {
                                            position,
                                        } => {
                                            completion.queue_state = QueueState::Queued { position };
                                        }
                                        CompletionRequestStatus::Started => {
                                            completion.queue_state = QueueState::Started;
                                        }
                                        CompletionRequestStatus::Failed {
                                            code, message, request_id
                                        } => {
                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
                                        }
                                        CompletionRequestStatus::UsageUpdated {
                                            amount, limit
                                        } => {
                                            let usage = RequestUsage { limit, amount: amount as i32 };

                                            thread.last_usage = Some(usage);
                                        }
                                        CompletionRequestStatus::ToolUseLimitReached => {
                                            thread.tool_use_limit_reached = true;
                                        }
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                        Ok(())
                    })??;

                    // Give other tasks a chance to run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread.last_received_chunk_at = None;
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // If there is a response without tool use, summarize the message. Otherwise,
                    // allow two tool uses before summarizing.
                    if matches!(thread.summary, ThreadSummary::Pending)
                        && thread.messages.len() >= 2
                        && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
                    {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(window, cx, model.clone());
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn | StopReason::MaxTokens => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });
                            }
                            StopReason::Refusal => {
                                thread.project.update(cx, |project, cx| {
                                    project.set_agent_location(None, cx);
                                });

                                // Remove the turn that was refused.
                                //
                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
                                {
                                    let mut messages_to_remove = Vec::new();

                                    // Walk backwards collecting messages until we
                                    // reach the user message that began this turn.
                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
                                        messages_to_remove.push(message.id);

                                        if message.role == Role::User {
                                            if ix == 0 {
                                                break;
                                            }

                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
                                                if prev_message.role == Role::Assistant {
                                                    break;
                                                }
                                            }
                                        }
                                    }

                                    for message_id in messages_to_remove {
                                        thread.delete_message(message_id, cx);
                                    }
                                }

                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Language model refusal".into(),
                                    message: "Model refused to generate content for safety reasons.".into(),
                                }));
                            }
                        },
                        Err(error) => {
                            thread.project.update(cx, |project, cx| {
                                project.set_agent_location(None, cx);
                            });

                            // Map known error types to dedicated UI states;
                            // anything else becomes a generic error message.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(window, cx);
                        }
                    }

                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    if let Some((request_callback, (request, response_events))) = thread
                        .request_callback
                        .as_mut()
                        .zip(request_callback_parameters.as_ref())
                    {
                        request_callback(request, response_events);
                    }

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            prompt_id = prompt_id,
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            queue_state: QueueState::Sending,
            _task: task,
        });
    }
1844
1845 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1846 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1847 println!("No thread summary model");
1848 return;
1849 };
1850
1851 if !model.provider.is_authenticated(cx) {
1852 return;
1853 }
1854
1855 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1856 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1857 If the conversation is about a specific subject, include it in the title. \
1858 Be descriptive. DO NOT speak in the first person.";
1859
1860 let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
1861
1862 self.summary = ThreadSummary::Generating;
1863
1864 self.pending_summary = cx.spawn(async move |this, cx| {
1865 let result = async {
1866 let mut messages = model.model.stream_completion(request, &cx).await?;
1867
1868 let mut new_summary = String::new();
1869 while let Some(event) = messages.next().await {
1870 let Ok(event) = event else {
1871 continue;
1872 };
1873 let text = match event {
1874 LanguageModelCompletionEvent::Text(text) => text,
1875 LanguageModelCompletionEvent::StatusUpdate(
1876 CompletionRequestStatus::UsageUpdated { amount, limit },
1877 ) => {
1878 this.update(cx, |thread, _cx| {
1879 thread.last_usage = Some(RequestUsage {
1880 limit,
1881 amount: amount as i32,
1882 });
1883 })?;
1884 continue;
1885 }
1886 _ => continue,
1887 };
1888
1889 let mut lines = text.lines();
1890 new_summary.extend(lines.next());
1891
1892 // Stop if the LLM generated multiple lines.
1893 if lines.next().is_some() {
1894 break;
1895 }
1896 }
1897
1898 anyhow::Ok(new_summary)
1899 }
1900 .await;
1901
1902 this.update(cx, |this, cx| {
1903 match result {
1904 Ok(new_summary) => {
1905 if new_summary.is_empty() {
1906 this.summary = ThreadSummary::Error;
1907 } else {
1908 this.summary = ThreadSummary::Ready(new_summary.into());
1909 }
1910 }
1911 Err(err) => {
1912 this.summary = ThreadSummary::Error;
1913 log::error!("Failed to generate thread summary: {}", err);
1914 }
1915 }
1916 cx.emit(ThreadEvent::SummaryGenerated);
1917 })
1918 .log_err()?;
1919
1920 Some(())
1921 });
1922 }
1923
    /// Starts generating a detailed Markdown summary of the thread, unless one
    /// is already generating or generated for the latest message.
    ///
    /// Publishes state transitions through `detailed_summary_tx`; on success it
    /// also saves the thread via the `thread_store` so the summary can be
    /// reused later.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the request itself fails, roll the state back to NotGenerated.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2013
    /// Waits for detailed-summary generation to settle, returning the summary
    /// if one was generated or the plain-text rendering of the thread if not.
    ///
    /// Returns `None` when the thread entity has been dropped or the state
    /// channel closes.
    pub async fn wait_for_detailed_summary_or_text(
        this: &Entity<Self>,
        cx: &mut AsyncApp,
    ) -> Option<SharedString> {
        let mut detailed_summary_rx = this
            .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
            .ok()?;
        loop {
            match detailed_summary_rx.recv().await? {
                // Still generating: wait for the next state change.
                DetailedSummaryState::Generating { .. } => {}
                DetailedSummaryState::NotGenerated => {
                    return this.read_with(cx, |this, _cx| this.text().into()).ok();
                }
                DetailedSummaryState::Generated { text, .. } => return Some(text),
            }
        }
    }
2031
2032 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2033 self.detailed_summary_rx
2034 .borrow()
2035 .text()
2036 .unwrap_or_else(|| self.text().into())
2037 }
2038
    /// Reports whether a detailed-summary generation task is currently in flight.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2045
    /// Dispatches every idle pending tool use: known tools that require
    /// confirmation (and aren't globally auto-allowed) emit
    /// `ToolConfirmationNeeded`, other known tools run immediately, and
    /// unrecognized tool names receive a hallucination error result.
    ///
    /// Returns the pending tool uses that were considered.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // Each tool run receives the completion request it was part of.
        let request = Arc::new(self.to_completion_request(model.clone(), cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AgentSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool we don't have.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
2099
2100 pub fn handle_hallucinated_tool_use(
2101 &mut self,
2102 tool_use_id: LanguageModelToolUseId,
2103 hallucinated_tool_name: Arc<str>,
2104 window: Option<AnyWindowHandle>,
2105 cx: &mut Context<Thread>,
2106 ) {
2107 let available_tools = self.tools.read(cx).enabled_tools(cx);
2108
2109 let tool_list = available_tools
2110 .iter()
2111 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2112 .collect::<Vec<_>>()
2113 .join("\n");
2114
2115 let error_message = format!(
2116 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2117 hallucinated_tool_name, tool_list
2118 );
2119
2120 let pending_tool_use = self.tool_use.insert_tool_output(
2121 tool_use_id.clone(),
2122 hallucinated_tool_name,
2123 Err(anyhow!("Missing tool call: {error_message}")),
2124 self.configured_model.as_ref(),
2125 );
2126
2127 cx.emit(ThreadEvent::MissingToolUse {
2128 tool_use_id: tool_use_id.clone(),
2129 ui_text: error_message.into(),
2130 });
2131
2132 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2133 }
2134
2135 pub fn receive_invalid_tool_json(
2136 &mut self,
2137 tool_use_id: LanguageModelToolUseId,
2138 tool_name: Arc<str>,
2139 invalid_json: Arc<str>,
2140 error: String,
2141 window: Option<AnyWindowHandle>,
2142 cx: &mut Context<Thread>,
2143 ) {
2144 log::error!("The model returned invalid input JSON: {invalid_json}");
2145
2146 let pending_tool_use = self.tool_use.insert_tool_output(
2147 tool_use_id.clone(),
2148 tool_name,
2149 Err(anyhow!("Error parsing input JSON: {error}")),
2150 self.configured_model.as_ref(),
2151 );
2152 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2153 pending_tool_use.ui_text.clone()
2154 } else {
2155 log::error!(
2156 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2157 );
2158 format!("Unknown tool {}", tool_use_id).into()
2159 };
2160
2161 cx.emit(ThreadEvent::InvalidToolInput {
2162 tool_use_id: tool_use_id.clone(),
2163 ui_text,
2164 invalid_input_json: invalid_json,
2165 });
2166
2167 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2168 }
2169
    /// Runs a (confirmed or auto-allowed) tool use: spawns the tool's task and
    /// marks the pending tool use as running under the given UI label.
    pub fn run_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<SharedString>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) {
        let task =
            self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
        self.tool_use
            .run_pending_tool(tool_use_id, ui_text.into(), task);
    }
2186
    /// Spawns the task that executes a tool and records its output on the
    /// thread via `tool_finished` when it completes.
    ///
    /// Disabled tools short-circuit to an error result without running. If the
    /// tool produced a card, it is stored so the UI can render it alongside
    /// the result.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                request,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                let output = tool_result.output.await;

                // The thread may have been dropped while the tool ran; ignore
                // the result in that case.
                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2237
2238 fn tool_finished(
2239 &mut self,
2240 tool_use_id: LanguageModelToolUseId,
2241 pending_tool_use: Option<PendingToolUse>,
2242 canceled: bool,
2243 window: Option<AnyWindowHandle>,
2244 cx: &mut Context<Self>,
2245 ) {
2246 if self.all_tools_finished() {
2247 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2248 if !canceled {
2249 self.send_to_model(model.clone(), window, cx);
2250 }
2251 self.auto_capture_telemetry(cx);
2252 }
2253 }
2254
2255 cx.emit(ThreadEvent::ToolFinished {
2256 tool_use_id,
2257 pending_tool_use,
2258 });
2259 }
2260
2261 /// Cancels the last pending completion, if there are any pending.
2262 ///
2263 /// Returns whether a completion was canceled.
2264 pub fn cancel_last_completion(
2265 &mut self,
2266 window: Option<AnyWindowHandle>,
2267 cx: &mut Context<Self>,
2268 ) -> bool {
2269 let mut canceled = self.pending_completions.pop().is_some();
2270
2271 for pending_tool_use in self.tool_use.cancel_pending() {
2272 canceled = true;
2273 self.tool_finished(
2274 pending_tool_use.id.clone(),
2275 Some(pending_tool_use),
2276 true,
2277 window,
2278 cx,
2279 );
2280 }
2281
2282 if canceled {
2283 cx.emit(ThreadEvent::CompletionCanceled);
2284
2285 // When canceled, we always want to insert the checkpoint.
2286 // (We skip over finalize_pending_checkpoint, because it
2287 // would conclude we didn't have anything to insert here.)
2288 if let Some(checkpoint) = self.pending_checkpoint.take() {
2289 self.insert_checkpoint(checkpoint, cx);
2290 }
2291 } else {
2292 self.finalize_pending_checkpoint(cx);
2293 }
2294
2295 canceled
2296 }
2297
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It does not mutate the
    /// thread itself; it only emits [`ThreadEvent::CancelEditing`].
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2305
2306 pub fn feedback(&self) -> Option<ThreadFeedback> {
2307 self.feedback
2308 }
2309
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2313
    /// Records user feedback for a single message and reports it via
    /// telemetry, attaching a serialized snapshot of the whole thread.
    ///
    /// No-ops (returning an already-resolved task) when the same feedback has
    /// already been recorded for this message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // NOTE(review): the thread is serialized before the new feedback is
        // stored below, so `thread_data` reflects the pre-feedback state —
        // confirm this ordering is intended.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            // The shorthand identifiers below double as telemetry field
            // names; do not rename them without updating the event schema.
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2371
    /// Records feedback for the thread as a whole.
    ///
    /// When the thread contains an assistant message, the feedback is
    /// attributed to the most recent one via [`Self::report_message_feedback`];
    /// otherwise it is stored thread-level and reported via telemetry.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                // The shorthand identifiers below double as telemetry field
                // names; do not rename them without updating the event schema.
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2417
2418 /// Create a snapshot of the current project state including git information and unsaved buffers.
2419 fn project_snapshot(
2420 project: Entity<Project>,
2421 cx: &mut Context<Self>,
2422 ) -> Task<Arc<ProjectSnapshot>> {
2423 let git_store = project.read(cx).git_store().clone();
2424 let worktree_snapshots: Vec<_> = project
2425 .read(cx)
2426 .visible_worktrees(cx)
2427 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2428 .collect();
2429
2430 cx.spawn(async move |_, cx| {
2431 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2432
2433 let mut unsaved_buffers = Vec::new();
2434 cx.update(|app_cx| {
2435 let buffer_store = project.read(app_cx).buffer_store();
2436 for buffer_handle in buffer_store.read(app_cx).buffers() {
2437 let buffer = buffer_handle.read(app_cx);
2438 if buffer.is_dirty() {
2439 if let Some(file) = buffer.file() {
2440 let path = file.path().to_string_lossy().to_string();
2441 unsaved_buffers.push(path);
2442 }
2443 }
2444 }
2445 })
2446 .ok();
2447
2448 Arc::new(ProjectSnapshot {
2449 worktree_snapshots,
2450 unsaved_buffer_paths: unsaved_buffers,
2451 timestamp: Utc::now(),
2452 })
2453 })
2454 }
2455
    /// Captures a [`WorktreeSnapshot`] for one worktree, including git
    /// metadata (remote URL, HEAD SHA, current branch, diff) when the worktree
    /// belongs to a locally-backed repository.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app was torn down while we were spawning, return an
            // empty snapshot rather than failing.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose checkout contains this worktree's
            // root path.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Non-local repository state yields only the
                            // branch name; everything else is unavailable.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested fallible layers (Option of Result of awaited
            // job result); any failure along the way collapses to None.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2533
2534 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2535 let mut markdown = Vec::new();
2536
2537 let summary = self.summary().or_default();
2538 writeln!(markdown, "# {summary}\n")?;
2539
2540 for message in self.messages() {
2541 writeln!(
2542 markdown,
2543 "## {role}\n",
2544 role = match message.role {
2545 Role::User => "User",
2546 Role::Assistant => "Agent",
2547 Role::System => "System",
2548 }
2549 )?;
2550
2551 if !message.loaded_context.text.is_empty() {
2552 writeln!(markdown, "{}", message.loaded_context.text)?;
2553 }
2554
2555 if !message.loaded_context.images.is_empty() {
2556 writeln!(
2557 markdown,
2558 "\n{} images attached as context.\n",
2559 message.loaded_context.images.len()
2560 )?;
2561 }
2562
2563 for segment in &message.segments {
2564 match segment {
2565 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2566 MessageSegment::Thinking { text, .. } => {
2567 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2568 }
2569 MessageSegment::RedactedThinking(_) => {}
2570 }
2571 }
2572
2573 for tool_use in self.tool_uses_for_message(message.id, cx) {
2574 writeln!(
2575 markdown,
2576 "**Use Tool: {} ({})**",
2577 tool_use.name, tool_use.id
2578 )?;
2579 writeln!(markdown, "```json")?;
2580 writeln!(
2581 markdown,
2582 "{}",
2583 serde_json::to_string_pretty(&tool_use.input)?
2584 )?;
2585 writeln!(markdown, "```")?;
2586 }
2587
2588 for tool_result in self.tool_results_for_message(message.id) {
2589 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2590 if tool_result.is_error {
2591 write!(markdown, " (Error)")?;
2592 }
2593
2594 writeln!(markdown, "**\n")?;
2595 match &tool_result.content {
2596 LanguageModelToolResultContent::Text(text)
2597 | LanguageModelToolResultContent::WrappedText(WrappedTextContent {
2598 text,
2599 ..
2600 }) => {
2601 writeln!(markdown, "{text}")?;
2602 }
2603 LanguageModelToolResultContent::Image(image) => {
2604 writeln!(markdown, "", image.source)?;
2605 }
2606 }
2607
2608 if let Some(output) = tool_result.output.as_ref() {
2609 writeln!(
2610 markdown,
2611 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2612 serde_json::to_string_pretty(output)?
2613 )?;
2614 }
2615 }
2616 }
2617
2618 Ok(String::from_utf8_lossy(&markdown).to_string())
2619 }
2620
    /// Marks agent edits within `buffer_range` of `buffer` as kept (accepted),
    /// delegating to the [`ActionLog`].
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2631
    /// Marks every outstanding agent edit as kept, delegating to the
    /// [`ActionLog`].
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2636
    /// Rejects (reverts) agent edits within the given ranges of `buffer`,
    /// delegating to the [`ActionLog`]. Returns the log's task.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2647
    /// Returns the thread's [`ActionLog`] entity.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2651
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2655
2656 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2657 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2658 return;
2659 }
2660
2661 let now = Instant::now();
2662 if let Some(last) = self.last_auto_capture_at {
2663 if now.duration_since(last).as_secs() < 10 {
2664 return;
2665 }
2666 }
2667
2668 self.last_auto_capture_at = Some(now);
2669
2670 let thread_id = self.id().clone();
2671 let github_login = self
2672 .project
2673 .read(cx)
2674 .user_store()
2675 .read(cx)
2676 .current_user()
2677 .map(|user| user.github_login.clone());
2678 let client = self.project.read(cx).client();
2679 let serialize_task = self.serialize(cx);
2680
2681 cx.background_executor()
2682 .spawn(async move {
2683 if let Ok(serialized_thread) = serialize_task.await {
2684 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2685 telemetry::event!(
2686 "Agent Thread Auto-Captured",
2687 thread_id = thread_id.to_string(),
2688 thread_data = thread_data,
2689 auto_capture_reason = "tracked_user",
2690 github_login = github_login
2691 );
2692
2693 client.telemetry().flush_events().await;
2694 }
2695 }
2696 })
2697 .detach();
2698 }
2699
    /// Returns the token usage accumulated across all requests in this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2703
2704 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2705 let Some(model) = self.configured_model.as_ref() else {
2706 return TotalTokenUsage::default();
2707 };
2708
2709 let max = model.model.max_token_count();
2710
2711 let index = self
2712 .messages
2713 .iter()
2714 .position(|msg| msg.id == message_id)
2715 .unwrap_or(0);
2716
2717 if index == 0 {
2718 return TotalTokenUsage { total: 0, max };
2719 }
2720
2721 let token_usage = &self
2722 .request_token_usage
2723 .get(index - 1)
2724 .cloned()
2725 .unwrap_or_default();
2726
2727 TotalTokenUsage {
2728 total: token_usage.total_tokens() as usize,
2729 max,
2730 }
2731 }
2732
2733 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2734 let model = self.configured_model.as_ref()?;
2735
2736 let max = model.model.max_token_count();
2737
2738 if let Some(exceeded_error) = &self.exceeded_window_error {
2739 if model.model.id() == exceeded_error.model_id {
2740 return Some(TotalTokenUsage {
2741 total: exceeded_error.token_count,
2742 max,
2743 });
2744 }
2745 }
2746
2747 let total = self
2748 .token_usage_at_last_message()
2749 .unwrap_or_default()
2750 .total_tokens() as usize;
2751
2752 Some(TotalTokenUsage { total, max })
2753 }
2754
    /// Returns the recorded token usage for the last message, falling back to
    /// the most recent recorded entry when there is no entry at the last
    /// message's index.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2761
    /// Records `token_usage` as the usage for the last message, growing the
    /// per-request usage vector to match the current message count.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Pad any newly-created entries with the last known usage so
        // intermediate slots aren't filled with zeros.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2771
2772 pub fn deny_tool_use(
2773 &mut self,
2774 tool_use_id: LanguageModelToolUseId,
2775 tool_name: Arc<str>,
2776 window: Option<AnyWindowHandle>,
2777 cx: &mut Context<Self>,
2778 ) {
2779 let err = Err(anyhow::anyhow!(
2780 "Permission to run tool action denied by user"
2781 ));
2782
2783 self.tool_use.insert_tool_output(
2784 tool_use_id.clone(),
2785 tool_name,
2786 err,
2787 self.configured_model.as_ref(),
2788 );
2789 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2790 }
2791}
2792
/// Errors surfaced to the user while running a thread.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Payment is required before further completions can be made.
    #[error("Payment required")]
    PaymentRequired,
    /// The request limit for the user's plan has been reached.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A generic error with a header and message to display.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2805
/// Events emitted by a [`Thread`] to notify subscribers (e.g. the UI) about
/// streaming progress, message lifecycle changes, and tool activity.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// A chunk of streamed assistant text for the given message.
    StreamedAssistantText(MessageId, String),
    /// A chunk of streamed assistant "thinking" text for the given message.
    StreamedAssistantThinking(MessageId, String),
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input JSON that could not be parsed.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion ended, either with a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running; emitted by `tool_finished` for both normal
    /// completion and cancellation.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
    /// Listeners (like ActiveThread) should cancel in-progress editing; see
    /// `cancel_editing`.
    CancelEditing,
    /// Emitted by `cancel_last_completion` when pending work was canceled.
    CompletionCanceled,
}
2848
// Marker impl allowing `Thread` entities to emit `ThreadEvent`s via gpui.
impl EventEmitter<ThreadEvent> for Thread {}
2850
/// A completion request that is currently in flight.
struct PendingCompletion {
    // Identifier for this completion request.
    id: usize,
    queue_state: QueueState,
    // Held so the spawned work stays alive; removing a `PendingCompletion`
    // (see `cancel_last_completion`) is what cancels the request.
    _task: Task<()>,
}
2856
2857#[cfg(test)]
2858mod tests {
2859 use super::*;
2860 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2861 use agent_settings::{AgentSettings, LanguageModelParameters};
2862 use assistant_tool::ToolRegistry;
2863 use editor::EditorSettings;
2864 use gpui::TestAppContext;
2865 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2866 use project::{FakeFs, Project};
2867 use prompt_store::PromptBuilder;
2868 use serde_json::json;
2869 use settings::{Settings, SettingsStore};
2870 use std::sync::Arc;
2871 use theme::ThemeSettings;
2872 use util::path;
2873 use workspace::Workspace;
2874
    /// Verifies that file context attached to a user message is loaded,
    /// stored on the message, and rendered into the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // messages[0] is the system prompt; the user message follows it.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2951
    /// Verifies that each new user message only picks up context that hasn't
    /// been attached to an earlier message, and that `new_context_for_thread`
    /// honors an exclusion boundary message id and context removals.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        //
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // Excluding message2 and later: contexts attached from message2 on
        // (files 2 and 3) plus the newly-added file4 count as "new" again.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3107
    /// Verifies that user messages with no attached context carry empty
    /// context text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3185
    /// Verifies that a "files changed since last read" notification is
    /// appended to the completion request only after a context buffer is
    /// modified.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3273
    /// Verifies how `model_parameters` temperature overrides are matched:
    /// by provider+model, by model alone, by provider alone, and that a
    /// mismatched provider yields no override.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), cx)
        });
        // The provider doesn't match, so no temperature override applies.
        assert_eq!(request.temperature, None);
    }
3367
3368 #[gpui::test]
3369 async fn test_thread_summary(cx: &mut TestAppContext) {
3370 init_test_settings(cx);
3371
3372 let project = create_test_project(cx, json!({})).await;
3373
3374 let (_, _thread_store, thread, _context_store, model) =
3375 setup_test_environment(cx, project.clone()).await;
3376
3377 // Initial state should be pending
3378 thread.read_with(cx, |thread, _| {
3379 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3380 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3381 });
3382
3383 // Manually setting the summary should not be allowed in this state
3384 thread.update(cx, |thread, cx| {
3385 thread.set_summary("This should not work", cx);
3386 });
3387
3388 thread.read_with(cx, |thread, _| {
3389 assert!(matches!(thread.summary(), ThreadSummary::Pending));
3390 });
3391
3392 // Send a message
3393 thread.update(cx, |thread, cx| {
3394 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3395 thread.send_to_model(model.clone(), None, cx);
3396 });
3397
3398 let fake_model = model.as_fake();
3399 simulate_successful_response(&fake_model, cx);
3400
3401 // Should start generating summary when there are >= 2 messages
3402 thread.read_with(cx, |thread, _| {
3403 assert_eq!(*thread.summary(), ThreadSummary::Generating);
3404 });
3405
3406 // Should not be able to set the summary while generating
3407 thread.update(cx, |thread, cx| {
3408 thread.set_summary("This should not work either", cx);
3409 });
3410
3411 thread.read_with(cx, |thread, _| {
3412 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3413 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3414 });
3415
3416 cx.run_until_parked();
3417 fake_model.stream_last_completion_response("Brief".into());
3418 fake_model.stream_last_completion_response(" Introduction".into());
3419 fake_model.end_last_completion_stream();
3420 cx.run_until_parked();
3421
3422 // Summary should be set
3423 thread.read_with(cx, |thread, _| {
3424 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3425 assert_eq!(thread.summary().or_default(), "Brief Introduction");
3426 });
3427
3428 // Now we should be able to set a summary
3429 thread.update(cx, |thread, cx| {
3430 thread.set_summary("Brief Intro", cx);
3431 });
3432
3433 thread.read_with(cx, |thread, _| {
3434 assert_eq!(thread.summary().or_default(), "Brief Intro");
3435 });
3436
3437 // Test setting an empty summary (should default to DEFAULT)
3438 thread.update(cx, |thread, cx| {
3439 thread.set_summary("", cx);
3440 });
3441
3442 thread.read_with(cx, |thread, _| {
3443 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3444 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3445 });
3446 }
3447
3448 #[gpui::test]
3449 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3450 init_test_settings(cx);
3451
3452 let project = create_test_project(cx, json!({})).await;
3453
3454 let (_, _thread_store, thread, _context_store, model) =
3455 setup_test_environment(cx, project.clone()).await;
3456
3457 test_summarize_error(&model, &thread, cx);
3458
3459 // Now we should be able to set a summary
3460 thread.update(cx, |thread, cx| {
3461 thread.set_summary("Brief Intro", cx);
3462 });
3463
3464 thread.read_with(cx, |thread, _| {
3465 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3466 assert_eq!(thread.summary().or_default(), "Brief Intro");
3467 });
3468 }
3469
3470 #[gpui::test]
3471 async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
3472 init_test_settings(cx);
3473
3474 let project = create_test_project(cx, json!({})).await;
3475
3476 let (_, _thread_store, thread, _context_store, model) =
3477 setup_test_environment(cx, project.clone()).await;
3478
3479 test_summarize_error(&model, &thread, cx);
3480
3481 // Sending another message should not trigger another summarize request
3482 thread.update(cx, |thread, cx| {
3483 thread.insert_user_message(
3484 "How are you?",
3485 ContextLoadResult::default(),
3486 None,
3487 vec![],
3488 cx,
3489 );
3490 thread.send_to_model(model.clone(), None, cx);
3491 });
3492
3493 let fake_model = model.as_fake();
3494 simulate_successful_response(&fake_model, cx);
3495
3496 thread.read_with(cx, |thread, _| {
3497 // State is still Error, not Generating
3498 assert!(matches!(thread.summary(), ThreadSummary::Error));
3499 });
3500
3501 // But the summarize request can be invoked manually
3502 thread.update(cx, |thread, cx| {
3503 thread.summarize(cx);
3504 });
3505
3506 thread.read_with(cx, |thread, _| {
3507 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3508 });
3509
3510 cx.run_until_parked();
3511 fake_model.stream_last_completion_response("A successful summary".into());
3512 fake_model.end_last_completion_stream();
3513 cx.run_until_parked();
3514
3515 thread.read_with(cx, |thread, _| {
3516 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3517 assert_eq!(thread.summary().or_default(), "A successful summary");
3518 });
3519 }
3520
3521 fn test_summarize_error(
3522 model: &Arc<dyn LanguageModel>,
3523 thread: &Entity<Thread>,
3524 cx: &mut TestAppContext,
3525 ) {
3526 thread.update(cx, |thread, cx| {
3527 thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
3528 thread.send_to_model(model.clone(), None, cx);
3529 });
3530
3531 let fake_model = model.as_fake();
3532 simulate_successful_response(&fake_model, cx);
3533
3534 thread.read_with(cx, |thread, _| {
3535 assert!(matches!(thread.summary(), ThreadSummary::Generating));
3536 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3537 });
3538
3539 // Simulate summary request ending
3540 cx.run_until_parked();
3541 fake_model.end_last_completion_stream();
3542 cx.run_until_parked();
3543
3544 // State is set to Error and default message
3545 thread.read_with(cx, |thread, _| {
3546 assert!(matches!(thread.summary(), ThreadSummary::Error));
3547 assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
3548 });
3549 }
3550
    /// Completes the model's pending request with a canned assistant reply.
    ///
    /// Parks the executor first so the in-flight request reaches the fake
    /// model, streams one response chunk, closes the stream, then parks again
    /// so the thread can process the finished turn.
    fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Assistant response".into());
        fake_model.end_last_completion_stream();
        cx.run_until_parked();
    }
3557
    /// Registers the global state every test in this module depends on.
    ///
    /// The test `SettingsStore` is installed first because the subsequent
    /// `init`/`register` calls read settings from it. NOTE(review): the
    /// relative order of the remaining calls is assumed to be load-bearing —
    /// append new registrations at the end rather than reordering.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3574
3575 // Helper to create a test project with test files
3576 async fn create_test_project(
3577 cx: &mut TestAppContext,
3578 files: serde_json::Value,
3579 ) -> Entity<Project> {
3580 let fs = FakeFs::new(cx.executor());
3581 fs.insert_tree(path!("/test"), files).await;
3582 Project::test(fs, [path!("/test").as_ref()], cx).await
3583 }
3584
3585 async fn setup_test_environment(
3586 cx: &mut TestAppContext,
3587 project: Entity<Project>,
3588 ) -> (
3589 Entity<Workspace>,
3590 Entity<ThreadStore>,
3591 Entity<Thread>,
3592 Entity<ContextStore>,
3593 Arc<dyn LanguageModel>,
3594 ) {
3595 let (workspace, cx) =
3596 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
3597
3598 let thread_store = cx
3599 .update(|_, cx| {
3600 ThreadStore::load(
3601 project.clone(),
3602 cx.new(|_| ToolWorkingSet::default()),
3603 None,
3604 Arc::new(PromptBuilder::new(None).unwrap()),
3605 cx,
3606 )
3607 })
3608 .await
3609 .unwrap();
3610
3611 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
3612 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
3613
3614 let provider = Arc::new(FakeLanguageModelProvider);
3615 let model = provider.test_model();
3616 let model: Arc<dyn LanguageModel> = Arc::new(model);
3617
3618 cx.update(|_, cx| {
3619 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
3620 registry.set_default_model(
3621 Some(ConfiguredModel {
3622 provider: provider.clone(),
3623 model: model.clone(),
3624 }),
3625 cx,
3626 );
3627 registry.set_thread_summary_model(
3628 Some(ConfiguredModel {
3629 provider,
3630 model: model.clone(),
3631 }),
3632 cx,
3633 );
3634 })
3635 });
3636
3637 (workspace, thread_store, thread, context_store, model)
3638 }
3639
3640 async fn add_file_to_context(
3641 project: &Entity<Project>,
3642 context_store: &Entity<ContextStore>,
3643 path: &str,
3644 cx: &mut TestAppContext,
3645 ) -> Result<Entity<language::Buffer>> {
3646 let buffer_path = project
3647 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3648 .unwrap();
3649
3650 let buffer = project
3651 .update(cx, |project, cx| {
3652 project.open_buffer(buffer_path.clone(), cx)
3653 })
3654 .await
3655 .unwrap();
3656
3657 context_store.update(cx, |context_store, cx| {
3658 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3659 });
3660
3661 Ok(buffer)
3662 }
3663}