1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use agent_settings::{AgentSettings, CompletionMode};
8use anyhow::{Result, anyhow};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use editor::display_map::CreaseMetadata;
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::future::Shared;
15use futures::{FutureExt, StreamExt as _};
16use git::repository::DiffType;
17use gpui::{
18 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
19 WeakEntity,
20};
21use language_model::{
22 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
25 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
26 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
27 StopReason, TokenUsage, WrappedTextContent,
28};
29use postage::stream::Stream as _;
30use project::Project;
31use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use schemars::JsonSchema;
35use serde::{Deserialize, Serialize};
36use settings::Settings;
37use thiserror::Error;
38use ui::Window;
39use util::{ResultExt as _, post_inc};
40use uuid::Uuid;
41use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
42
43use crate::ThreadStore;
44use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
45use crate::thread_store::{
46 SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
47 SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
48};
49use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
50
51#[derive(
52 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
53)]
54pub struct ThreadId(Arc<str>);
55
56impl ThreadId {
57 pub fn new() -> Self {
58 Self(Uuid::new_v4().to_string().into())
59 }
60}
61
62impl std::fmt::Display for ThreadId {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl From<&str> for ThreadId {
69 fn from(value: &str) -> Self {
70 Self(value.into())
71 }
72}
73
74/// The ID of the user prompt that initiated a request.
75///
76/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
77#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
78pub struct PromptId(Arc<str>);
79
80impl PromptId {
81 pub fn new() -> Self {
82 Self(Uuid::new_v4().to_string().into())
83 }
84}
85
86impl std::fmt::Display for PromptId {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 write!(f, "{}", self.0)
89 }
90}
91
/// Identifier of a message within a thread; assigned from a monotonically
/// increasing counter, so later messages always have larger ids.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current id and advances the counter to the next one.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
100
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Offset range of the crease within the message's text.
    pub range: Range<usize>,
    /// Icon and label used to render the collapsed crease.
    pub metadata: CreaseMetadata,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text/thinking segments making up the message body.
    pub segments: Vec<MessageSegment>,
    /// Context that was attached when this message was sent; only the
    /// flattened text survives serialization (see `Thread::deserialize`).
    pub loaded_context: LoadedContext,
    /// Creases used to restore context mentions when editing this message.
    pub creases: Vec<MessageCrease>,
    /// Whether this is a synthetic message (e.g. the automatic "continue"
    /// prompt inserted by `insert_invisible_continue_message`).
    pub is_hidden: bool,
}
120
121impl Message {
122 /// Returns whether the message contains any meaningful text that should be displayed
123 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
124 pub fn should_display_content(&self) -> bool {
125 self.segments.iter().all(|segment| segment.should_display())
126 }
127
128 pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
129 if let Some(MessageSegment::Thinking {
130 text: segment,
131 signature: current_signature,
132 }) = self.segments.last_mut()
133 {
134 if let Some(signature) = signature {
135 *current_signature = Some(signature);
136 }
137 segment.push_str(text);
138 } else {
139 self.segments.push(MessageSegment::Thinking {
140 text: text.to_string(),
141 signature,
142 });
143 }
144 }
145
146 pub fn push_text(&mut self, text: &str) {
147 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
148 segment.push_str(text);
149 } else {
150 self.segments.push(MessageSegment::Text(text.to_string()));
151 }
152 }
153
154 pub fn to_string(&self) -> String {
155 let mut result = String::new();
156
157 if !self.loaded_context.text.is_empty() {
158 result.push_str(&self.loaded_context.text);
159 }
160
161 for segment in &self.segments {
162 match segment {
163 MessageSegment::Text(text) => result.push_str(text),
164 MessageSegment::Thinking { text, .. } => {
165 result.push_str("<think>\n");
166 result.push_str(text);
167 result.push_str("\n</think>");
168 }
169 MessageSegment::RedactedThinking(_) => {}
170 }
171 }
172
173 result
174 }
175}
176
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain response text.
    Text(String),
    /// Chain-of-thought content, optionally carrying a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted thinking: opaque bytes with nothing to render.
    RedactedThinking(Vec<u8>),
}

impl MessageSegment {
    /// Whether this segment carries content worth rendering.
    ///
    /// Text and thinking segments are displayable only when non-empty; the
    /// previous `text.is_empty()` check had the polarity inverted, making
    /// empty segments "displayable" and real text not. Redacted thinking is
    /// never displayed.
    pub fn should_display(&self) -> bool {
        match self {
            Self::Text(text) => !text.is_empty(),
            Self::Thinking { text, .. } => !text.is_empty(),
            Self::RedactedThinking(_) => false,
        }
    }
}
196
/// Snapshot of the project's worktrees, captured for storage/telemetry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree state captured in a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Git metadata for the worktree, when available.
    pub git_state: Option<GitState>,
}

/// Git repository metadata at snapshot time. All fields are optional since
/// any of them may be unavailable (e.g. no remote, detached HEAD).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}

/// Pairs a git checkpoint with the message whose submission created it.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
223
224#[derive(Copy, Clone, Debug, PartialEq, Eq)]
225pub enum ThreadFeedback {
226 Positive,
227 Negative,
228}
229
230pub enum LastRestoreCheckpoint {
231 Pending {
232 message_id: MessageId,
233 },
234 Error {
235 message_id: MessageId,
236 error: String,
237 },
238}
239
240impl LastRestoreCheckpoint {
241 pub fn message_id(&self) -> MessageId {
242 match self {
243 LastRestoreCheckpoint::Pending { message_id } => *message_id,
244 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
245 }
246 }
247}
248
249#[derive(Clone, Debug, Default, Serialize, Deserialize)]
250pub enum DetailedSummaryState {
251 #[default]
252 NotGenerated,
253 Generating {
254 message_id: MessageId,
255 },
256 Generated {
257 text: SharedString,
258 message_id: MessageId,
259 },
260}
261
262impl DetailedSummaryState {
263 fn text(&self) -> Option<SharedString> {
264 if let Self::Generated { text, .. } = self {
265 Some(text.clone())
266 } else {
267 None
268 }
269 }
270}
271
/// Running token total for a thread alongside the model's context limit.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens consumed so far.
    pub total: usize,
    /// Context window size; 0 when no model is selected.
    pub max: usize,
}

impl TotalTokenUsage {
    /// Classifies the current usage against the context window.
    ///
    /// In debug builds the warning threshold may be overridden via the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable (a float,
    /// fraction of the window); release builds always warn at 80%.
    pub fn ratio(&self) -> TokenUsageRatio {
        // Parse the override lazily: the previous version allocated a default
        // `String` even when the variable was set, and panicked without a
        // message on a malformed value.
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .map(|value| {
                value
                    .parse()
                    .expect("ZED_THREAD_WARNING_THRESHOLD must be a float")
            })
            .unwrap_or(0.8);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more consumed.
    /// Saturates instead of overflowing on pathological inputs.
    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total.saturating_add(tokens),
            max: self.max,
        }
    }
}

/// How close the thread is to exhausting the model's context window.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
316
/// Progress of a pending completion request before streaming begins.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    Sending,
    Queued { position: usize },
    Started,
}
323
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Last mutation time; bumped by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    summary: ThreadSummary,
    pending_summary: Task<Option<()>>,
    detailed_summary_task: Task<Option<()>>,
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    completion_mode: agent_settings::CompletionMode,
    /// Messages in ascending id order (ids are assigned monotonically),
    /// which `message()` relies on for binary search.
    messages: Vec<Message>,
    next_message_id: MessageId,
    last_prompt_id: PromptId,
    project_context: SharedProjectContext,
    /// Git checkpoints keyed by the user message that created them.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Status of the most recent checkpoint restore, if any.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured with the latest user message; kept or dropped by
    /// `finalize_pending_checkpoint` once generation finishes.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    last_usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    /// NOTE(review): presumably throttles `auto_capture_telemetry` — confirm.
    last_auto_capture_at: Option<Instant>,
    /// Arrival time of the last streamed chunk; drives `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
    /// Observer invoked with each request and its streamed events.
    request_callback: Option<
        Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
    >,
    /// Turns `send_to_model` may still consume; sends are ignored at 0.
    remaining_turns: u32,
    configured_model: Option<ConfiguredModel>,
}
364
365#[derive(Clone, Debug, PartialEq, Eq)]
366pub enum ThreadSummary {
367 Pending,
368 Generating,
369 Ready(SharedString),
370 Error,
371}
372
373impl ThreadSummary {
374 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
375
376 pub fn or_default(&self) -> SharedString {
377 self.unwrap_or(Self::DEFAULT)
378 }
379
380 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
381 self.ready().unwrap_or_else(|| message.into())
382 }
383
384 pub fn ready(&self) -> Option<SharedString> {
385 match self {
386 ThreadSummary::Ready(summary) => Some(summary.clone()),
387 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
388 }
389 }
390}
391
/// Records that a request overflowed the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
399
400impl Thread {
401 pub fn new(
402 project: Entity<Project>,
403 tools: Entity<ToolWorkingSet>,
404 prompt_builder: Arc<PromptBuilder>,
405 system_prompt: SharedProjectContext,
406 cx: &mut Context<Self>,
407 ) -> Self {
408 let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
409 let configured_model = LanguageModelRegistry::read_global(cx).default_model();
410
411 Self {
412 id: ThreadId::new(),
413 updated_at: Utc::now(),
414 summary: ThreadSummary::Pending,
415 pending_summary: Task::ready(None),
416 detailed_summary_task: Task::ready(None),
417 detailed_summary_tx,
418 detailed_summary_rx,
419 completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
420 messages: Vec::new(),
421 next_message_id: MessageId(0),
422 last_prompt_id: PromptId::new(),
423 project_context: system_prompt,
424 checkpoints_by_message: HashMap::default(),
425 completion_count: 0,
426 pending_completions: Vec::new(),
427 project: project.clone(),
428 prompt_builder,
429 tools: tools.clone(),
430 last_restore_checkpoint: None,
431 pending_checkpoint: None,
432 tool_use: ToolUseState::new(tools.clone()),
433 action_log: cx.new(|_| ActionLog::new(project.clone())),
434 initial_project_snapshot: {
435 let project_snapshot = Self::project_snapshot(project, cx);
436 cx.foreground_executor()
437 .spawn(async move { Some(project_snapshot.await) })
438 .shared()
439 },
440 request_token_usage: Vec::new(),
441 cumulative_token_usage: TokenUsage::default(),
442 exceeded_window_error: None,
443 last_usage: None,
444 tool_use_limit_reached: false,
445 feedback: None,
446 message_feedback: HashMap::default(),
447 last_auto_capture_at: None,
448 last_received_chunk_at: None,
449 request_callback: None,
450 remaining_turns: u32::MAX,
451 configured_model,
452 }
453 }
454
    /// Reconstructs a thread from its serialized on-disk form.
    ///
    /// `window` is `None` in headless mode; it is only needed to revive tool
    /// state for previously executed tools.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut Context<Self>,
    ) -> Self {
        // Ids are monotonically increasing, so the next id is one past the
        // last serialized message (or 0 for an empty thread).
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use = ToolUseState::from_serialized_messages(
            tools.clone(),
            &serialized.messages,
            project.clone(),
            window,
            cx,
        );
        let (detailed_summary_tx, detailed_summary_rx) =
            postage::watch::channel_with(serialized.detailed_summary_state);

        // Prefer the model recorded in the serialized thread; fall back to
        // the registry default when it is missing or no longer available.
        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            serialized
                .model
                .and_then(|model| {
                    let model = SelectedModel {
                        provider: model.provider.clone().into(),
                        model: model.model.clone().into(),
                    };
                    registry.select_model(&model, cx)
                })
                .or_else(|| registry.default_model())
        });

        let completion_mode = serialized
            .completion_mode
            .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: ThreadSummary::Ready(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text, signature } => {
                                MessageSegment::Thinking { text, signature }
                            }
                            SerializedMessageSegment::RedactedThinking { data } => {
                                MessageSegment::RedactedThinking(data)
                            }
                        })
                        .collect(),
                    // Only the flattened context text is persisted; the
                    // structured context handles are not restored.
                    loaded_context: LoadedContext {
                        contexts: Vec::new(),
                        text: message.context,
                        images: Vec::new(),
                    },
                    creases: message
                        .creases
                        .into_iter()
                        .map(|crease| MessageCrease {
                            range: crease.start..crease.end,
                            metadata: CreaseMetadata {
                                icon_path: crease.icon_path,
                                label: crease.label,
                            },
                            // Deserialized creases carry no live context handle.
                            context: None,
                        })
                        .collect(),
                    is_hidden: message.is_hidden,
                })
                .collect(),
            next_message_id,
            last_prompt_id: PromptId::new(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            last_usage: None,
            tool_use_limit_reached: serialized.tool_use_limit_reached,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
            request_callback: None,
            remaining_turns: u32::MAX,
            configured_model,
        }
    }
575
    /// Installs a callback that observes every completion request and the
    /// stream of events received in response.
    pub fn set_request_callback(
        &mut self,
        callback: impl 'static
        + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
    ) {
        self.request_callback = Some(Box::new(callback));
    }

    pub fn id(&self) -> &ThreadId {
        &self.id
    }

    /// Whether the thread has no messages yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }

    /// Bumps the modification timestamp to now.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }

    /// Starts a fresh prompt id for the next user submission.
    pub fn advance_prompt_id(&mut self) {
        self.last_prompt_id = PromptId::new();
    }

    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }

    /// Returns the configured model, initializing it from (and caching) the
    /// registry's default when none has been selected yet.
    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
        }
        self.configured_model.clone()
    }

    pub fn configured_model(&self) -> Option<ConfiguredModel> {
        self.configured_model.clone()
    }

    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
        self.configured_model = model;
        cx.notify();
    }
623
624 pub fn summary(&self) -> &ThreadSummary {
625 &self.summary
626 }
627
628 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
629 let current_summary = match &self.summary {
630 ThreadSummary::Pending | ThreadSummary::Generating => return,
631 ThreadSummary::Ready(summary) => summary,
632 ThreadSummary::Error => &ThreadSummary::DEFAULT,
633 };
634
635 let mut new_summary = new_summary.into();
636
637 if new_summary.is_empty() {
638 new_summary = ThreadSummary::DEFAULT;
639 }
640
641 if current_summary != &new_summary {
642 self.summary = ThreadSummary::Ready(new_summary);
643 cx.emit(ThreadEvent::SummaryChanged);
644 }
645 }
646
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }

    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }

    /// Looks up a message by id.
    ///
    /// Ids come from a monotonically increasing counter, so `messages` is
    /// always sorted by id and binary search is valid here.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        let index = self
            .messages
            .binary_search_by(|message| message.id.cmp(&id))
            .ok()?;

        self.messages.get(index)
    }

    /// Iterates over all messages in insertion (id) order.
    pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
667
    /// Whether a completion is in flight or any tool is still running.
    pub fn is_generating(&self) -> bool {
        !self.pending_completions.is_empty() || !self.all_tools_finished()
    }

    /// Indicates whether streaming of language model events is stale.
    ///
    /// Returns `None` while no chunk has ever been received (including when
    /// nothing is generating); otherwise whether the last chunk arrived more
    /// than 250ms ago. Callers should gate on `is_generating()` first.
    pub fn is_generation_stale(&self) -> Option<bool> {
        const STALE_THRESHOLD: u128 = 250;

        self.last_received_chunk_at
            .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
    }

    /// Records the arrival of a streamed chunk for staleness tracking.
    fn received_chunk(&mut self) {
        self.last_received_chunk_at = Some(Instant::now());
    }

    /// Queue status of the first pending completion, if any.
    pub fn queue_state(&self) -> Option<QueueState> {
        self.pending_completions
            .first()
            .map(|pending_completion| pending_completion.queue_state)
    }
690
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }

    /// The pending tool use with the given id, if any.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }

    /// Pending tool uses that require user confirmation before they can run.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }

    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }

    /// The git checkpoint recorded for `id`, if one exists.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
716
    /// Restores the project to `checkpoint`'s git state.
    ///
    /// Marks the restore as pending for UI feedback, then asks the git store
    /// to restore asynchronously. On success the thread is truncated back to
    /// the checkpoint's message; on failure the error is recorded so it can
    /// be surfaced next to that message.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                // Either way, the pending checkpoint no longer applies after
                // a restore attempt.
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
752
753 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
754 let pending_checkpoint = if self.is_generating() {
755 return;
756 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
757 checkpoint
758 } else {
759 return;
760 };
761
762 self.finalize_checkpoint(pending_checkpoint, cx);
763 }
764
    /// Decides whether `pending_checkpoint` is worth keeping by comparing it
    /// against a fresh checkpoint of the current repository state.
    ///
    /// The checkpoint is recorded only when the repository changed since it
    /// was taken — or when that comparison cannot be made — so no-op turns
    /// don't accumulate checkpoints.
    fn finalize_checkpoint(
        &mut self,
        pending_checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) {
        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    // Treat a failed comparison as "changed" so the
                    // checkpoint is kept.
                    .unwrap_or(false);

                if !equal {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                Ok(())
            }
            // Taking the final checkpoint failed; keep the pending one
            // rather than silently dropping it.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
799
    /// Records `checkpoint` for its message and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }

    /// Status of the most recent checkpoint restore, if pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
810
811 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
812 let Some(message_ix) = self
813 .messages
814 .iter()
815 .rposition(|message| message.id == message_id)
816 else {
817 return;
818 };
819 for deleted_message in self.messages.drain(message_ix..) {
820 self.checkpoints_by_message.remove(&deleted_message.id);
821 }
822 cx.notify();
823 }
824
825 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
826 self.messages
827 .iter()
828 .find(|message| message.id == id)
829 .into_iter()
830 .flat_map(|message| message.loaded_context.contexts.iter())
831 }
832
833 pub fn is_turn_end(&self, ix: usize) -> bool {
834 if self.messages.is_empty() {
835 return false;
836 }
837
838 if !self.is_generating() && ix == self.messages.len() - 1 {
839 return true;
840 }
841
842 let Some(message) = self.messages.get(ix) else {
843 return false;
844 };
845
846 if message.role != Role::Assistant {
847 return false;
848 }
849
850 self.messages
851 .get(ix + 1)
852 .and_then(|message| {
853 self.message(message.id)
854 .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
855 })
856 .unwrap_or(false)
857 }
858
    /// Usage reported by the most recent completion response, if any.
    pub fn last_usage(&self) -> Option<RequestUsage> {
        self.last_usage
    }

    /// Whether the provider signalled that the tool-use limit was hit.
    pub fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }

    /// Returns whether all of the tool uses have finished running.
    pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // that means that we've finished running all of the pending tools.
        self.tool_use
            .pending_tool_uses()
            .iter()
            .all(|tool_use| tool_use.status.is_error())
    }
876
    /// Tool uses issued by the assistant message with `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }

    /// Results of tool uses issued by the given assistant message.
    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(assistant_message_id)
    }

    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }

    /// The textual output of a completed tool use, if it produced text.
    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
        match &self.tool_use.tool_result(id)?.content {
            LanguageModelToolResultContent::Text(text)
            | LanguageModelToolResultContent::WrappedText(WrappedTextContent { text, .. }) => {
                Some(text)
            }
            LanguageModelToolResultContent::Image(_) => {
                // TODO: We should display image
                None
            }
        }
    }

    /// The UI card rendered for a tool's result, if the tool produced one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
908
909 /// Return tools that are both enabled and supported by the model
910 pub fn available_tools(
911 &self,
912 cx: &App,
913 model: Arc<dyn LanguageModel>,
914 ) -> Vec<LanguageModelRequestTool> {
915 if model.supports_tools() {
916 self.tools()
917 .read(cx)
918 .enabled_tools(cx)
919 .into_iter()
920 .filter_map(|tool| {
921 // Skip tools that cannot be supported
922 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
923 Some(LanguageModelRequestTool {
924 name: tool.name(),
925 description: tool.description(),
926 input_schema,
927 })
928 })
929 .collect()
930 } else {
931 Vec::default()
932 }
933 }
934
    /// Appends a user message, marking any context-referenced buffers as read
    /// in the action log and remembering `git_checkpoint` as the pending
    /// checkpoint to be finalized when the turn completes.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        loaded_context: ContextLoadResult,
        git_checkpoint: Option<GitStoreCheckpoint>,
        creases: Vec<MessageCrease>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        if !loaded_context.referenced_buffers.is_empty() {
            self.action_log.update(cx, |log, cx| {
                for buffer in loaded_context.referenced_buffers {
                    log.buffer_read(buffer, cx);
                }
            });
        }

        let message_id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text(text.into())],
            loaded_context.loaded_context,
            creases,
            false,
            cx,
        );

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
971
    /// Appends a hidden user message asking the model to continue, used to
    /// resume a turn without new user input.
    pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
        let id = self.insert_message(
            Role::User,
            vec![MessageSegment::Text("Continue where you left off".into())],
            LoadedContext::default(),
            vec![],
            true,
            cx,
        );
        // A synthetic continuation shouldn't produce a restorable checkpoint.
        self.pending_checkpoint = None;

        id
    }

    /// Appends a visible assistant message with the given segments.
    pub fn insert_assistant_message(
        &mut self,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        self.insert_message(
            Role::Assistant,
            segments,
            LoadedContext::default(),
            Vec::new(),
            false,
            cx,
        )
    }
1000
    /// Appends a message with a freshly allocated id, bumps the thread's
    /// modification time, and emits [`ThreadEvent::MessageAdded`].
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        loaded_context: LoadedContext,
        creases: Vec<MessageCrease>,
        is_hidden: bool,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            loaded_context,
            creases,
            is_hidden,
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
1023
1024 pub fn edit_message(
1025 &mut self,
1026 id: MessageId,
1027 new_role: Role,
1028 new_segments: Vec<MessageSegment>,
1029 loaded_context: Option<LoadedContext>,
1030 checkpoint: Option<GitStoreCheckpoint>,
1031 cx: &mut Context<Self>,
1032 ) -> bool {
1033 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
1034 return false;
1035 };
1036 message.role = new_role;
1037 message.segments = new_segments;
1038 if let Some(context) = loaded_context {
1039 message.loaded_context = context;
1040 }
1041 if let Some(git_checkpoint) = checkpoint {
1042 self.checkpoints_by_message.insert(
1043 id,
1044 ThreadCheckpoint {
1045 message_id: id,
1046 git_checkpoint,
1047 },
1048 );
1049 }
1050 self.touch_updated_at();
1051 cx.emit(ThreadEvent::MessageEdited(id));
1052 true
1053 }
1054
1055 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
1056 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
1057 return false;
1058 };
1059 self.messages.remove(index);
1060 self.touch_updated_at();
1061 cx.emit(ThreadEvent::MessageDeleted(id));
1062 true
1063 }
1064
1065 /// Returns the representation of this [`Thread`] in a textual form.
1066 ///
1067 /// This is the representation we use when attaching a thread as context to another thread.
1068 pub fn text(&self) -> String {
1069 let mut text = String::new();
1070
1071 for message in &self.messages {
1072 text.push_str(match message.role {
1073 language_model::Role::User => "User:",
1074 language_model::Role::Assistant => "Agent:",
1075 language_model::Role::System => "System:",
1076 });
1077 text.push('\n');
1078
1079 for segment in &message.segments {
1080 match segment {
1081 MessageSegment::Text(content) => text.push_str(content),
1082 MessageSegment::Thinking { text: content, .. } => {
1083 text.push_str(&format!("<think>{}</think>", content))
1084 }
1085 MessageSegment::RedactedThinking(_) => {}
1086 }
1087 }
1088 text.push('\n');
1089 }
1090
1091 text
1092 }
1093
    /// Serializes this thread into a format for storage or telemetry.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Await the snapshot before borrowing the entity so the read
            // below stays synchronous.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary().or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking { text, signature } => {
                                    SerializedMessageSegment::Thinking {
                                        text: text.clone(),
                                        signature: signature.clone(),
                                    }
                                }
                                MessageSegment::RedactedThinking(data) => {
                                    SerializedMessageSegment::RedactedThinking {
                                        data: data.clone(),
                                    }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                                output: tool_result.output.clone(),
                            })
                            .collect(),
                        // Only the flattened context text survives a
                        // round-trip; see `deserialize`.
                        context: message.loaded_context.text.clone(),
                        creases: message
                            .creases
                            .iter()
                            .map(|crease| SerializedCrease {
                                start: crease.range.start,
                                end: crease.range.end,
                                icon_path: crease.metadata.icon_path.clone(),
                                label: crease.metadata.label.clone(),
                            })
                            .collect(),
                        is_hidden: message.is_hidden,
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
                model: this
                    .configured_model
                    .as_ref()
                    .map(|model| SerializedLanguageModel {
                        provider: model.provider.id().0.to_string(),
                        model: model.model.id().0.to_string(),
                    }),
                completion_mode: Some(this.completion_mode),
                tool_use_limit_reached: this.tool_use_limit_reached,
            })
        })
    }
1178
    /// Returns how many request/response turns this thread may still initiate
    /// before `send_to_model` becomes a no-op.
    pub fn remaining_turns(&self) -> u32 {
        self.remaining_turns
    }
1182
    /// Sets the number of turns this thread may still initiate.
    pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }
1186
1187 pub fn send_to_model(
1188 &mut self,
1189 model: Arc<dyn LanguageModel>,
1190 intent: CompletionIntent,
1191 window: Option<AnyWindowHandle>,
1192 cx: &mut Context<Self>,
1193 ) {
1194 if self.remaining_turns == 0 {
1195 return;
1196 }
1197
1198 self.remaining_turns -= 1;
1199
1200 let request = self.to_completion_request(model.clone(), intent, cx);
1201
1202 self.stream_completion(request, model, window, cx);
1203 }
1204
1205 pub fn used_tools_since_last_user_message(&self) -> bool {
1206 for message in self.messages.iter().rev() {
1207 if self.tool_use.message_has_tool_results(message.id) {
1208 return true;
1209 } else if message.role == Role::User {
1210 return false;
1211 }
1212 }
1213
1214 false
1215 }
1216
    /// Builds the [`LanguageModelRequest`] for the current conversation state:
    /// system prompt, all messages with tool uses/results interleaved,
    /// stale-file notices, the available tools, and the completion mode.
    pub fn to_completion_request(
        &self,
        model: Arc<dyn LanguageModel>,
        intent: CompletionIntent,
        cx: &mut Context<Self>,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            thread_id: Some(self.id.to_string()),
            prompt_id: Some(self.last_prompt_id.to_string()),
            intent: Some(intent),
            mode: None,
            messages: vec![],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: AgentSettings::temperature_for_model(&model, cx),
        };

        let available_tools = self.available_tools(cx, model.clone());
        let available_tool_names = available_tools
            .iter()
            .map(|tool| tool.name.clone())
            .collect();

        let model_context = &ModelContext {
            available_tools: available_tool_names,
        };

        // Generate the system prompt from the project context. If the context
        // isn't ready or generation fails, we surface an error but still send
        // the request without a system message.
        if let Some(project_context) = self.project_context.borrow().as_ref() {
            match self
                .prompt_builder
                .generate_assistant_system_prompt(project_context, model_context)
            {
                Err(err) => {
                    let message = format!("{err:?}").into();
                    log::error!("{message}");
                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                        header: "Error generating system prompt".into(),
                        message,
                    }));
                }
                Ok(system_prompt) => {
                    request.messages.push(LanguageModelRequestMessage {
                        role: Role::System,
                        content: vec![MessageContent::Text(system_prompt)],
                        cache: true,
                    });
                }
            }
        } else {
            let message = "Context for system prompt unexpectedly not ready.".into();
            log::error!("{message}");
            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                header: "Error generating system prompt".into(),
                message,
            }));
        }

        // Index of the last request message eligible for prompt caching; only
        // messages whose tool uses have all completed may be cached.
        let mut message_ix_to_cache = None;
        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Attached context (files, symbols, etc.) goes in before the
            // message's own segments.
            message
                .loaded_context
                .add_to_request_message(&mut request_message);

            for segment in &message.segments {
                match segment {
                    MessageSegment::Text(text) => {
                        if !text.is_empty() {
                            request_message
                                .content
                                .push(MessageContent::Text(text.into()));
                        }
                    }
                    MessageSegment::Thinking { text, signature } => {
                        if !text.is_empty() {
                            request_message.content.push(MessageContent::Thinking {
                                text: text.into(),
                                signature: signature.clone(),
                            });
                        }
                    }
                    MessageSegment::RedactedThinking(data) => {
                        request_message
                            .content
                            .push(MessageContent::RedactedThinking(data.clone()));
                    }
                };
            }

            // Completed tool uses are attached to the assistant message, and
            // their results are gathered into a synthetic follow-up user
            // message, as the API expects.
            let mut cache_message = true;
            let mut tool_results_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };
            for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
                if let Some(tool_result) = tool_result {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                    tool_results_message
                        .content
                        .push(MessageContent::ToolResult(LanguageModelToolResult {
                            tool_use_id: tool_use.id.clone(),
                            tool_name: tool_result.tool_name.clone(),
                            is_error: tool_result.is_error,
                            content: if tool_result.content.is_empty() {
                                // Surprisingly, the API fails if we return an empty string here.
                                // It thinks we are sending a tool use without a tool result.
                                "<Tool returned an empty string>".into()
                            } else {
                                tool_result.content.clone()
                            },
                            output: None,
                        }));
                } else {
                    // A pending tool use means this message isn't stable yet,
                    // so it must not be used as the cache anchor.
                    cache_message = false;
                    log::debug!(
                        "skipped tool use {:?} because it is still pending",
                        tool_use
                    );
                }
            }

            if cache_message {
                message_ix_to_cache = Some(request.messages.len());
            }
            request.messages.push(request_message);

            if !tool_results_message.content.is_empty() {
                if cache_message {
                    message_ix_to_cache = Some(request.messages.len());
                }
                request.messages.push(tool_results_message);
            }
        }

        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        if let Some(message_ix_to_cache) = message_ix_to_cache {
            request.messages[message_ix_to_cache].cache = true;
        }

        self.attached_tracked_files_state(&mut request.messages, cx);

        request.tools = available_tools;
        // Burn/max mode is only requested when the model supports it;
        // otherwise fall back to normal mode.
        request.mode = if model.supports_max_mode() {
            Some(self.completion_mode.into())
        } else {
            Some(CompletionMode::Normal.into())
        };

        request
    }
1376
1377 fn to_summarize_request(
1378 &self,
1379 model: &Arc<dyn LanguageModel>,
1380 intent: CompletionIntent,
1381 added_user_message: String,
1382 cx: &App,
1383 ) -> LanguageModelRequest {
1384 let mut request = LanguageModelRequest {
1385 thread_id: None,
1386 prompt_id: None,
1387 intent: Some(intent),
1388 mode: None,
1389 messages: vec![],
1390 tools: Vec::new(),
1391 tool_choice: None,
1392 stop: Vec::new(),
1393 temperature: AgentSettings::temperature_for_model(model, cx),
1394 };
1395
1396 for message in &self.messages {
1397 let mut request_message = LanguageModelRequestMessage {
1398 role: message.role,
1399 content: Vec::new(),
1400 cache: false,
1401 };
1402
1403 for segment in &message.segments {
1404 match segment {
1405 MessageSegment::Text(text) => request_message
1406 .content
1407 .push(MessageContent::Text(text.clone())),
1408 MessageSegment::Thinking { .. } => {}
1409 MessageSegment::RedactedThinking(_) => {}
1410 }
1411 }
1412
1413 if request_message.content.is_empty() {
1414 continue;
1415 }
1416
1417 request.messages.push(request_message);
1418 }
1419
1420 request.messages.push(LanguageModelRequestMessage {
1421 role: Role::User,
1422 content: vec![MessageContent::Text(added_user_message)],
1423 cache: false,
1424 });
1425
1426 request
1427 }
1428
1429 fn attached_tracked_files_state(
1430 &self,
1431 messages: &mut Vec<LanguageModelRequestMessage>,
1432 cx: &App,
1433 ) {
1434 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1435
1436 let mut stale_message = String::new();
1437
1438 let action_log = self.action_log.read(cx);
1439
1440 for stale_file in action_log.stale_buffers(cx) {
1441 let Some(file) = stale_file.read(cx).file() else {
1442 continue;
1443 };
1444
1445 if stale_message.is_empty() {
1446 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1447 }
1448
1449 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1450 }
1451
1452 let mut content = Vec::with_capacity(2);
1453
1454 if !stale_message.is_empty() {
1455 content.push(stale_message.into());
1456 }
1457
1458 if !content.is_empty() {
1459 let context_message = LanguageModelRequestMessage {
1460 role: Role::User,
1461 content,
1462 cache: false,
1463 };
1464
1465 messages.push(context_message);
1466 }
1467 }
1468
1469 pub fn stream_completion(
1470 &mut self,
1471 request: LanguageModelRequest,
1472 model: Arc<dyn LanguageModel>,
1473 window: Option<AnyWindowHandle>,
1474 cx: &mut Context<Self>,
1475 ) {
1476 self.tool_use_limit_reached = false;
1477
1478 let pending_completion_id = post_inc(&mut self.completion_count);
1479 let mut request_callback_parameters = if self.request_callback.is_some() {
1480 Some((request.clone(), Vec::new()))
1481 } else {
1482 None
1483 };
1484 let prompt_id = self.last_prompt_id.clone();
1485 let tool_use_metadata = ToolUseMetadata {
1486 model: model.clone(),
1487 thread_id: self.id.clone(),
1488 prompt_id: prompt_id.clone(),
1489 };
1490
1491 self.last_received_chunk_at = Some(Instant::now());
1492
1493 let task = cx.spawn(async move |thread, cx| {
1494 let stream_completion_future = model.stream_completion(request, &cx);
1495 let initial_token_usage =
1496 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1497 let stream_completion = async {
1498 let mut events = stream_completion_future.await?;
1499
1500 let mut stop_reason = StopReason::EndTurn;
1501 let mut current_token_usage = TokenUsage::default();
1502
1503 thread
1504 .update(cx, |_thread, cx| {
1505 cx.emit(ThreadEvent::NewRequest);
1506 })
1507 .ok();
1508
1509 let mut request_assistant_message_id = None;
1510
1511 while let Some(event) = events.next().await {
1512 if let Some((_, response_events)) = request_callback_parameters.as_mut() {
1513 response_events
1514 .push(event.as_ref().map_err(|error| error.to_string()).cloned());
1515 }
1516
1517 thread.update(cx, |thread, cx| {
1518 let event = match event {
1519 Ok(event) => event,
1520 Err(LanguageModelCompletionError::BadInputJson {
1521 id,
1522 tool_name,
1523 raw_input: invalid_input_json,
1524 json_parse_error,
1525 }) => {
1526 thread.receive_invalid_tool_json(
1527 id,
1528 tool_name,
1529 invalid_input_json,
1530 json_parse_error,
1531 window,
1532 cx,
1533 );
1534 return Ok(());
1535 }
1536 Err(LanguageModelCompletionError::Other(error)) => {
1537 return Err(error);
1538 }
1539 };
1540
1541 match event {
1542 LanguageModelCompletionEvent::StartMessage { .. } => {
1543 request_assistant_message_id =
1544 Some(thread.insert_assistant_message(
1545 vec![MessageSegment::Text(String::new())],
1546 cx,
1547 ));
1548 }
1549 LanguageModelCompletionEvent::Stop(reason) => {
1550 stop_reason = reason;
1551 }
1552 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1553 thread.update_token_usage_at_last_message(token_usage);
1554 thread.cumulative_token_usage = thread.cumulative_token_usage
1555 + token_usage
1556 - current_token_usage;
1557 current_token_usage = token_usage;
1558 }
1559 LanguageModelCompletionEvent::Text(chunk) => {
1560 thread.received_chunk();
1561
1562 cx.emit(ThreadEvent::ReceivedTextChunk);
1563 if let Some(last_message) = thread.messages.last_mut() {
1564 if last_message.role == Role::Assistant
1565 && !thread.tool_use.has_tool_results(last_message.id)
1566 {
1567 last_message.push_text(&chunk);
1568 cx.emit(ThreadEvent::StreamedAssistantText(
1569 last_message.id,
1570 chunk,
1571 ));
1572 } else {
1573 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1574 // of a new Assistant response.
1575 //
1576 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1577 // will result in duplicating the text of the chunk in the rendered Markdown.
1578 request_assistant_message_id =
1579 Some(thread.insert_assistant_message(
1580 vec![MessageSegment::Text(chunk.to_string())],
1581 cx,
1582 ));
1583 };
1584 }
1585 }
1586 LanguageModelCompletionEvent::Thinking {
1587 text: chunk,
1588 signature,
1589 } => {
1590 thread.received_chunk();
1591
1592 if let Some(last_message) = thread.messages.last_mut() {
1593 if last_message.role == Role::Assistant
1594 && !thread.tool_use.has_tool_results(last_message.id)
1595 {
1596 last_message.push_thinking(&chunk, signature);
1597 cx.emit(ThreadEvent::StreamedAssistantThinking(
1598 last_message.id,
1599 chunk,
1600 ));
1601 } else {
1602 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1603 // of a new Assistant response.
1604 //
1605 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1606 // will result in duplicating the text of the chunk in the rendered Markdown.
1607 request_assistant_message_id =
1608 Some(thread.insert_assistant_message(
1609 vec![MessageSegment::Thinking {
1610 text: chunk.to_string(),
1611 signature,
1612 }],
1613 cx,
1614 ));
1615 };
1616 }
1617 }
1618 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1619 let last_assistant_message_id = request_assistant_message_id
1620 .unwrap_or_else(|| {
1621 let new_assistant_message_id =
1622 thread.insert_assistant_message(vec![], cx);
1623 request_assistant_message_id =
1624 Some(new_assistant_message_id);
1625 new_assistant_message_id
1626 });
1627
1628 let tool_use_id = tool_use.id.clone();
1629 let streamed_input = if tool_use.is_input_complete {
1630 None
1631 } else {
1632 Some((&tool_use.input).clone())
1633 };
1634
1635 let ui_text = thread.tool_use.request_tool_use(
1636 last_assistant_message_id,
1637 tool_use,
1638 tool_use_metadata.clone(),
1639 cx,
1640 );
1641
1642 if let Some(input) = streamed_input {
1643 cx.emit(ThreadEvent::StreamedToolUse {
1644 tool_use_id,
1645 ui_text,
1646 input,
1647 });
1648 }
1649 }
1650 LanguageModelCompletionEvent::StatusUpdate(status_update) => {
1651 if let Some(completion) = thread
1652 .pending_completions
1653 .iter_mut()
1654 .find(|completion| completion.id == pending_completion_id)
1655 {
1656 match status_update {
1657 CompletionRequestStatus::Queued {
1658 position,
1659 } => {
1660 completion.queue_state = QueueState::Queued { position };
1661 }
1662 CompletionRequestStatus::Started => {
1663 completion.queue_state = QueueState::Started;
1664 }
1665 CompletionRequestStatus::Failed {
1666 code, message, request_id
1667 } => {
1668 anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
1669 }
1670 CompletionRequestStatus::UsageUpdated {
1671 amount, limit
1672 } => {
1673 let usage = RequestUsage { limit, amount: amount as i32 };
1674
1675 thread.last_usage = Some(usage);
1676 }
1677 CompletionRequestStatus::ToolUseLimitReached => {
1678 thread.tool_use_limit_reached = true;
1679 }
1680 }
1681 }
1682 }
1683 }
1684
1685 thread.touch_updated_at();
1686 cx.emit(ThreadEvent::StreamedCompletion);
1687 cx.notify();
1688
1689 thread.auto_capture_telemetry(cx);
1690 Ok(())
1691 })??;
1692
1693 smol::future::yield_now().await;
1694 }
1695
1696 thread.update(cx, |thread, cx| {
1697 thread.last_received_chunk_at = None;
1698 thread
1699 .pending_completions
1700 .retain(|completion| completion.id != pending_completion_id);
1701
1702 // If there is a response without tool use, summarize the message. Otherwise,
1703 // allow two tool uses before summarizing.
1704 if matches!(thread.summary, ThreadSummary::Pending)
1705 && thread.messages.len() >= 2
1706 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
1707 {
1708 thread.summarize(cx);
1709 }
1710 })?;
1711
1712 anyhow::Ok(stop_reason)
1713 };
1714
1715 let result = stream_completion.await;
1716
1717 thread
1718 .update(cx, |thread, cx| {
1719 thread.finalize_pending_checkpoint(cx);
1720 match result.as_ref() {
1721 Ok(stop_reason) => match stop_reason {
1722 StopReason::ToolUse => {
1723 let tool_uses = thread.use_pending_tools(window, cx, model.clone());
1724 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1725 }
1726 StopReason::EndTurn | StopReason::MaxTokens => {
1727 thread.project.update(cx, |project, cx| {
1728 project.set_agent_location(None, cx);
1729 });
1730 }
1731 StopReason::Refusal => {
1732 thread.project.update(cx, |project, cx| {
1733 project.set_agent_location(None, cx);
1734 });
1735
1736 // Remove the turn that was refused.
1737 //
1738 // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
1739 {
1740 let mut messages_to_remove = Vec::new();
1741
1742 for (ix, message) in thread.messages.iter().enumerate().rev() {
1743 messages_to_remove.push(message.id);
1744
1745 if message.role == Role::User {
1746 if ix == 0 {
1747 break;
1748 }
1749
1750 if let Some(prev_message) = thread.messages.get(ix - 1) {
1751 if prev_message.role == Role::Assistant {
1752 break;
1753 }
1754 }
1755 }
1756 }
1757
1758 for message_id in messages_to_remove {
1759 thread.delete_message(message_id, cx);
1760 }
1761 }
1762
1763 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1764 header: "Language model refusal".into(),
1765 message: "Model refused to generate content for safety reasons.".into(),
1766 }));
1767 }
1768 },
1769 Err(error) => {
1770 thread.project.update(cx, |project, cx| {
1771 project.set_agent_location(None, cx);
1772 });
1773
1774 if error.is::<PaymentRequiredError>() {
1775 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1776 } else if let Some(error) =
1777 error.downcast_ref::<ModelRequestLimitReachedError>()
1778 {
1779 cx.emit(ThreadEvent::ShowError(
1780 ThreadError::ModelRequestLimitReached { plan: error.plan },
1781 ));
1782 } else if let Some(known_error) =
1783 error.downcast_ref::<LanguageModelKnownError>()
1784 {
1785 match known_error {
1786 LanguageModelKnownError::ContextWindowLimitExceeded {
1787 tokens,
1788 } => {
1789 thread.exceeded_window_error = Some(ExceededWindowError {
1790 model_id: model.id(),
1791 token_count: *tokens,
1792 });
1793 cx.notify();
1794 }
1795 }
1796 } else {
1797 let error_message = error
1798 .chain()
1799 .map(|err| err.to_string())
1800 .collect::<Vec<_>>()
1801 .join("\n");
1802 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1803 header: "Error interacting with language model".into(),
1804 message: SharedString::from(error_message.clone()),
1805 }));
1806 }
1807
1808 thread.cancel_last_completion(window, cx);
1809 }
1810 }
1811
1812 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1813
1814 if let Some((request_callback, (request, response_events))) = thread
1815 .request_callback
1816 .as_mut()
1817 .zip(request_callback_parameters.as_ref())
1818 {
1819 request_callback(request, response_events);
1820 }
1821
1822 thread.auto_capture_telemetry(cx);
1823
1824 if let Ok(initial_usage) = initial_token_usage {
1825 let usage = thread.cumulative_token_usage - initial_usage;
1826
1827 telemetry::event!(
1828 "Assistant Thread Completion",
1829 thread_id = thread.id().to_string(),
1830 prompt_id = prompt_id,
1831 model = model.telemetry_id(),
1832 model_provider = model.provider_id().to_string(),
1833 input_tokens = usage.input_tokens,
1834 output_tokens = usage.output_tokens,
1835 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1836 cache_read_input_tokens = usage.cache_read_input_tokens,
1837 );
1838 }
1839 })
1840 .ok();
1841 });
1842
1843 self.pending_completions.push(PendingCompletion {
1844 id: pending_completion_id,
1845 queue_state: QueueState::Sending,
1846 _task: task,
1847 });
1848 }
1849
1850 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1851 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1852 println!("No thread summary model");
1853 return;
1854 };
1855
1856 if !model.provider.is_authenticated(cx) {
1857 return;
1858 }
1859
1860 let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1861 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1862 If the conversation is about a specific subject, include it in the title. \
1863 Be descriptive. DO NOT speak in the first person.";
1864
1865 let request = self.to_summarize_request(
1866 &model.model,
1867 CompletionIntent::ThreadSummarization,
1868 added_user_message.into(),
1869 cx,
1870 );
1871
1872 self.summary = ThreadSummary::Generating;
1873
1874 self.pending_summary = cx.spawn(async move |this, cx| {
1875 let result = async {
1876 let mut messages = model.model.stream_completion(request, &cx).await?;
1877
1878 let mut new_summary = String::new();
1879 while let Some(event) = messages.next().await {
1880 let Ok(event) = event else {
1881 continue;
1882 };
1883 let text = match event {
1884 LanguageModelCompletionEvent::Text(text) => text,
1885 LanguageModelCompletionEvent::StatusUpdate(
1886 CompletionRequestStatus::UsageUpdated { amount, limit },
1887 ) => {
1888 this.update(cx, |thread, _cx| {
1889 thread.last_usage = Some(RequestUsage {
1890 limit,
1891 amount: amount as i32,
1892 });
1893 })?;
1894 continue;
1895 }
1896 _ => continue,
1897 };
1898
1899 let mut lines = text.lines();
1900 new_summary.extend(lines.next());
1901
1902 // Stop if the LLM generated multiple lines.
1903 if lines.next().is_some() {
1904 break;
1905 }
1906 }
1907
1908 anyhow::Ok(new_summary)
1909 }
1910 .await;
1911
1912 this.update(cx, |this, cx| {
1913 match result {
1914 Ok(new_summary) => {
1915 if new_summary.is_empty() {
1916 this.summary = ThreadSummary::Error;
1917 } else {
1918 this.summary = ThreadSummary::Ready(new_summary.into());
1919 }
1920 }
1921 Err(err) => {
1922 this.summary = ThreadSummary::Error;
1923 log::error!("Failed to generate thread summary: {}", err);
1924 }
1925 }
1926 cx.emit(ThreadEvent::SummaryGenerated);
1927 })
1928 .log_err()?;
1929
1930 Some(())
1931 });
1932 }
1933
    /// Starts background generation of a detailed Markdown summary of the
    /// thread, unless one is already generated (or generating) for the
    /// current last message.
    ///
    /// Progress is published through `detailed_summary_tx`; on success the
    /// thread is also saved via `thread_store` so the summary can be reused.
    pub fn start_generating_detailed_summary_if_needed(
        &mut self,
        thread_store: WeakEntity<ThreadStore>,
        cx: &mut Context<Self>,
    ) {
        // Nothing to summarize in an empty thread.
        let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
            return;
        };

        match &*self.detailed_summary_rx.borrow() {
            DetailedSummaryState::Generating { message_id, .. }
            | DetailedSummaryState::Generated { message_id, .. }
                if *message_id == last_message_id =>
            {
                // Already up-to-date
                return;
            }
            _ => {}
        }

        let Some(ConfiguredModel { model, provider }) =
            LanguageModelRegistry::read_global(cx).thread_summary_model()
        else {
            return;
        };

        if !provider.is_authenticated(cx) {
            return;
        }

        let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
            1. A brief overview of what was discussed\n\
            2. Key facts or information discovered\n\
            3. Outcomes or conclusions reached\n\
            4. Any action items or next steps if any\n\
            Format it in Markdown with headings and bullet points.";

        let request = self.to_summarize_request(
            &model,
            CompletionIntent::ThreadContextSummarization,
            added_user_message.into(),
            cx,
        );

        // Mark generation as in-flight for the current last message, so
        // repeated calls while generating are no-ops (see the check above).
        *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
            message_id: last_message_id,
        };

        // Replace the detailed summarization task if there is one, cancelling it. It would probably
        // be better to allow the old task to complete, but this would require logic for choosing
        // which result to prefer (the old task could complete after the new one, resulting in a
        // stale summary).
        self.detailed_summary_task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion_text(request, &cx);
            // If the stream can't even be opened, roll the state back to
            // NotGenerated so a later attempt can retry.
            let Some(mut messages) = stream.await.log_err() else {
                thread
                    .update(cx, |thread, _cx| {
                        *thread.detailed_summary_tx.borrow_mut() =
                            DetailedSummaryState::NotGenerated;
                    })
                    .ok()?;
                return None;
            };

            let mut new_detailed_summary = String::new();

            while let Some(chunk) = messages.stream.next().await {
                if let Some(chunk) = chunk.log_err() {
                    new_detailed_summary.push_str(&chunk);
                }
            }

            thread
                .update(cx, |thread, _cx| {
                    *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
                        text: new_detailed_summary.into(),
                        message_id: last_message_id,
                    };
                })
                .ok()?;

            // Save thread so its summary can be reused later
            if let Some(thread) = thread.upgrade() {
                if let Ok(Ok(save_task)) = cx.update(|cx| {
                    thread_store
                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
                }) {
                    save_task.await.log_err();
                }
            }

            Some(())
        });
    }
2028
2029 pub async fn wait_for_detailed_summary_or_text(
2030 this: &Entity<Self>,
2031 cx: &mut AsyncApp,
2032 ) -> Option<SharedString> {
2033 let mut detailed_summary_rx = this
2034 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
2035 .ok()?;
2036 loop {
2037 match detailed_summary_rx.recv().await? {
2038 DetailedSummaryState::Generating { .. } => {}
2039 DetailedSummaryState::NotGenerated => {
2040 return this.read_with(cx, |this, _cx| this.text().into()).ok();
2041 }
2042 DetailedSummaryState::Generated { text, .. } => return Some(text),
2043 }
2044 }
2045 }
2046
2047 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
2048 self.detailed_summary_rx
2049 .borrow()
2050 .text()
2051 .unwrap_or_else(|| self.text().into())
2052 }
2053
    /// Whether a detailed summary is currently being generated in the background.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            &*self.detailed_summary_rx.borrow(),
            DetailedSummaryState::Generating { .. }
        )
    }
2060
    /// Dispatches every idle pending tool use: tools that require confirmation
    /// emit a confirmation request, others run immediately, and tool names
    /// that don't resolve to an enabled tool are answered with an error
    /// listing the tools that are available.
    ///
    /// Returns the pending tool uses that were dispatched.
    pub fn use_pending_tools(
        &mut self,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Self>,
        model: Arc<dyn LanguageModel>,
    ) -> Vec<PendingToolUse> {
        self.auto_capture_telemetry(cx);
        // One request snapshot is shared by all tool runs spawned below.
        let request =
            Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
        let pending_tool_uses = self
            .tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.is_idle())
            .cloned()
            .collect::<Vec<_>>();

        for tool_use in pending_tool_uses.iter() {
            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                // Ask the user first unless they've opted into always allowing
                // tool actions.
                if tool.needs_confirmation(&tool_use.input, cx)
                    && !AgentSettings::get_global(cx).always_allow_tool_actions
                {
                    self.tool_use.confirm_tool_use(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                    );
                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
                } else {
                    self.run_tool(
                        tool_use.id.clone(),
                        tool_use.ui_text.clone(),
                        tool_use.input.clone(),
                        request.clone(),
                        tool,
                        model.clone(),
                        window,
                        cx,
                    );
                }
            } else {
                // The model asked for a tool we don't have.
                self.handle_hallucinated_tool_use(
                    tool_use.id.clone(),
                    tool_use.name.clone(),
                    window,
                    cx,
                );
            }
        }

        pending_tool_uses
    }
2115
2116 pub fn handle_hallucinated_tool_use(
2117 &mut self,
2118 tool_use_id: LanguageModelToolUseId,
2119 hallucinated_tool_name: Arc<str>,
2120 window: Option<AnyWindowHandle>,
2121 cx: &mut Context<Thread>,
2122 ) {
2123 let available_tools = self.tools.read(cx).enabled_tools(cx);
2124
2125 let tool_list = available_tools
2126 .iter()
2127 .map(|tool| format!("- {}: {}", tool.name(), tool.description()))
2128 .collect::<Vec<_>>()
2129 .join("\n");
2130
2131 let error_message = format!(
2132 "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}",
2133 hallucinated_tool_name, tool_list
2134 );
2135
2136 let pending_tool_use = self.tool_use.insert_tool_output(
2137 tool_use_id.clone(),
2138 hallucinated_tool_name,
2139 Err(anyhow!("Missing tool call: {error_message}")),
2140 self.configured_model.as_ref(),
2141 );
2142
2143 cx.emit(ThreadEvent::MissingToolUse {
2144 tool_use_id: tool_use_id.clone(),
2145 ui_text: error_message.into(),
2146 });
2147
2148 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2149 }
2150
2151 pub fn receive_invalid_tool_json(
2152 &mut self,
2153 tool_use_id: LanguageModelToolUseId,
2154 tool_name: Arc<str>,
2155 invalid_json: Arc<str>,
2156 error: String,
2157 window: Option<AnyWindowHandle>,
2158 cx: &mut Context<Thread>,
2159 ) {
2160 log::error!("The model returned invalid input JSON: {invalid_json}");
2161
2162 let pending_tool_use = self.tool_use.insert_tool_output(
2163 tool_use_id.clone(),
2164 tool_name,
2165 Err(anyhow!("Error parsing input JSON: {error}")),
2166 self.configured_model.as_ref(),
2167 );
2168 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
2169 pending_tool_use.ui_text.clone()
2170 } else {
2171 log::error!(
2172 "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
2173 );
2174 format!("Unknown tool {}", tool_use_id).into()
2175 };
2176
2177 cx.emit(ThreadEvent::InvalidToolInput {
2178 tool_use_id: tool_use_id.clone(),
2179 ui_text,
2180 invalid_input_json: invalid_json,
2181 });
2182
2183 self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
2184 }
2185
2186 pub fn run_tool(
2187 &mut self,
2188 tool_use_id: LanguageModelToolUseId,
2189 ui_text: impl Into<SharedString>,
2190 input: serde_json::Value,
2191 request: Arc<LanguageModelRequest>,
2192 tool: Arc<dyn Tool>,
2193 model: Arc<dyn LanguageModel>,
2194 window: Option<AnyWindowHandle>,
2195 cx: &mut Context<Thread>,
2196 ) {
2197 let task =
2198 self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
2199 self.tool_use
2200 .run_pending_tool(tool_use_id, ui_text.into(), task);
2201 }
2202
    /// Spawns the actual execution of a single tool call, recording its output
    /// on the thread when it completes. Returns the task driving the run.
    fn spawn_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        request: Arc<LanguageModelRequest>,
        input: serde_json::Value,
        tool: Arc<dyn Tool>,
        model: Arc<dyn LanguageModel>,
        window: Option<AnyWindowHandle>,
        cx: &mut Context<Thread>,
    ) -> Task<()> {
        let tool_name: Arc<str> = tool.name().into();

        // A tool can still be requested while disabled (e.g. disabled after the
        // request was made); short-circuit with an error result in that case.
        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
        } else {
            tool.run(
                input,
                request,
                self.project.clone(),
                self.action_log.clone(),
                model,
                window,
                cx,
            )
        };

        // Store the card separately if it exists
        if let Some(card) = tool_result.card.clone() {
            self.tool_use
                .insert_tool_result_card(tool_use_id.clone(), card);
        }

        cx.spawn({
            async move |thread: WeakEntity<Thread>, cx| {
                // Await the tool's output, then record it and notify listeners.
                let output = tool_result.output.await;

                thread
                    .update(cx, |thread, cx| {
                        let pending_tool_use = thread.tool_use.insert_tool_output(
                            tool_use_id.clone(),
                            tool_name,
                            output,
                            thread.configured_model.as_ref(),
                        );
                        thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                    })
                    .ok();
            }
        })
    }
2253
2254 fn tool_finished(
2255 &mut self,
2256 tool_use_id: LanguageModelToolUseId,
2257 pending_tool_use: Option<PendingToolUse>,
2258 canceled: bool,
2259 window: Option<AnyWindowHandle>,
2260 cx: &mut Context<Self>,
2261 ) {
2262 if self.all_tools_finished() {
2263 if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
2264 if !canceled {
2265 self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
2266 }
2267 self.auto_capture_telemetry(cx);
2268 }
2269 }
2270
2271 cx.emit(ThreadEvent::ToolFinished {
2272 tool_use_id,
2273 pending_tool_use,
2274 });
2275 }
2276
2277 /// Cancels the last pending completion, if there are any pending.
2278 ///
2279 /// Returns whether a completion was canceled.
2280 pub fn cancel_last_completion(
2281 &mut self,
2282 window: Option<AnyWindowHandle>,
2283 cx: &mut Context<Self>,
2284 ) -> bool {
2285 let mut canceled = self.pending_completions.pop().is_some();
2286
2287 for pending_tool_use in self.tool_use.cancel_pending() {
2288 canceled = true;
2289 self.tool_finished(
2290 pending_tool_use.id.clone(),
2291 Some(pending_tool_use),
2292 true,
2293 window,
2294 cx,
2295 );
2296 }
2297
2298 if canceled {
2299 cx.emit(ThreadEvent::CompletionCanceled);
2300
2301 // When canceled, we always want to insert the checkpoint.
2302 // (We skip over finalize_pending_checkpoint, because it
2303 // would conclude we didn't have anything to insert here.)
2304 if let Some(checkpoint) = self.pending_checkpoint.take() {
2305 self.insert_checkpoint(checkpoint, cx);
2306 }
2307 } else {
2308 self.finalize_pending_checkpoint(cx);
2309 }
2310
2311 canceled
2312 }
2313
    /// Signals that any in-progress editing should be canceled.
    ///
    /// This method is used to notify listeners (like ActiveThread) that
    /// they should cancel any editing operations. It only emits
    /// [`ThreadEvent::CancelEditing`] and mutates no thread state itself.
    pub fn cancel_editing(&mut self, cx: &mut Context<Self>) {
        cx.emit(ThreadEvent::CancelEditing);
    }
2321
    /// Returns the thread-level feedback (set by `report_feedback`), if any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
2325
    /// Returns the feedback recorded for a specific message, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
2329
    /// Records `feedback` for a single message and reports it to telemetry
    /// together with a serialized snapshot of the thread and project.
    ///
    /// No-ops (returns an immediately-ready `Ok`) when the same feedback was
    /// already recorded for this message.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Kick off snapshot/serialization before recording the new feedback.
        // NOTE(review): keep this ordering — reordering relative to the insert
        // below could change what gets serialized; confirm before touching.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .read(cx)
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Serialization failures degrade to a null payload rather than erroring.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events().await;

            Ok(())
        })
    }
2387
    /// Reports thread-level feedback.
    ///
    /// If there is an assistant message, the feedback is attached to the most
    /// recent one via [`Self::report_message_feedback`]; otherwise it is
    /// stored on the thread itself and reported without message details.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Serialization failures degrade to a null payload rather than erroring.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events().await;

                Ok(())
            })
        }
    }
2433
2434 /// Create a snapshot of the current project state including git information and unsaved buffers.
2435 fn project_snapshot(
2436 project: Entity<Project>,
2437 cx: &mut Context<Self>,
2438 ) -> Task<Arc<ProjectSnapshot>> {
2439 let git_store = project.read(cx).git_store().clone();
2440 let worktree_snapshots: Vec<_> = project
2441 .read(cx)
2442 .visible_worktrees(cx)
2443 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
2444 .collect();
2445
2446 cx.spawn(async move |_, cx| {
2447 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
2448
2449 let mut unsaved_buffers = Vec::new();
2450 cx.update(|app_cx| {
2451 let buffer_store = project.read(app_cx).buffer_store();
2452 for buffer_handle in buffer_store.read(app_cx).buffers() {
2453 let buffer = buffer_handle.read(app_cx);
2454 if buffer.is_dirty() {
2455 if let Some(file) = buffer.file() {
2456 let path = file.path().to_string_lossy().to_string();
2457 unsaved_buffers.push(path);
2458 }
2459 }
2460 }
2461 })
2462 .ok();
2463
2464 Arc::new(ProjectSnapshot {
2465 worktree_snapshots,
2466 unsaved_buffer_paths: unsaved_buffers,
2467 timestamp: Utc::now(),
2468 })
2469 })
2470 }
2471
    /// Captures a [`WorktreeSnapshot`] (path plus, when available, git state)
    /// for a single worktree.
    ///
    /// All failures are tolerated: a dead app context yields an empty path,
    /// and any failure while querying the repository yields `git_state: None`.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Find the repository whose paths contain this worktree's root.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        // Query the backend off the main thread via the repo's job queue.
                        repo.send_job(None, |state, _| async move {
                            // Non-local repositories only report the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested Option<Result<Task>> — any failure means no git state.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
2549
2550 pub fn to_markdown(&self, cx: &App) -> Result<String> {
2551 let mut markdown = Vec::new();
2552
2553 let summary = self.summary().or_default();
2554 writeln!(markdown, "# {summary}\n")?;
2555
2556 for message in self.messages() {
2557 writeln!(
2558 markdown,
2559 "## {role}\n",
2560 role = match message.role {
2561 Role::User => "User",
2562 Role::Assistant => "Agent",
2563 Role::System => "System",
2564 }
2565 )?;
2566
2567 if !message.loaded_context.text.is_empty() {
2568 writeln!(markdown, "{}", message.loaded_context.text)?;
2569 }
2570
2571 if !message.loaded_context.images.is_empty() {
2572 writeln!(
2573 markdown,
2574 "\n{} images attached as context.\n",
2575 message.loaded_context.images.len()
2576 )?;
2577 }
2578
2579 for segment in &message.segments {
2580 match segment {
2581 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
2582 MessageSegment::Thinking { text, .. } => {
2583 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
2584 }
2585 MessageSegment::RedactedThinking(_) => {}
2586 }
2587 }
2588
2589 for tool_use in self.tool_uses_for_message(message.id, cx) {
2590 writeln!(
2591 markdown,
2592 "**Use Tool: {} ({})**",
2593 tool_use.name, tool_use.id
2594 )?;
2595 writeln!(markdown, "```json")?;
2596 writeln!(
2597 markdown,
2598 "{}",
2599 serde_json::to_string_pretty(&tool_use.input)?
2600 )?;
2601 writeln!(markdown, "```")?;
2602 }
2603
2604 for tool_result in self.tool_results_for_message(message.id) {
2605 write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
2606 if tool_result.is_error {
2607 write!(markdown, " (Error)")?;
2608 }
2609
2610 writeln!(markdown, "**\n")?;
2611 match &tool_result.content {
2612 LanguageModelToolResultContent::Text(text)
2613 | LanguageModelToolResultContent::WrappedText(WrappedTextContent {
2614 text,
2615 ..
2616 }) => {
2617 writeln!(markdown, "{text}")?;
2618 }
2619 LanguageModelToolResultContent::Image(image) => {
2620 writeln!(markdown, "", image.source)?;
2621 }
2622 }
2623
2624 if let Some(output) = tool_result.output.as_ref() {
2625 writeln!(
2626 markdown,
2627 "\n\nDebug Output:\n\n```json\n{}\n```\n",
2628 serde_json::to_string_pretty(output)?
2629 )?;
2630 }
2631 }
2632 }
2633
2634 Ok(String::from_utf8_lossy(&markdown).to_string())
2635 }
2636
    /// Delegates to `ActionLog::keep_edits_in_range` for the given buffer range.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
2647
    /// Delegates to `ActionLog::keep_all_edits`.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
2652
    /// Delegates to `ActionLog::reject_edits_in_ranges`, returning its task.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
2663
    /// Returns a handle to this thread's [`ActionLog`].
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
2667
    /// Returns a handle to the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
2671
2672 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
2673 if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
2674 return;
2675 }
2676
2677 let now = Instant::now();
2678 if let Some(last) = self.last_auto_capture_at {
2679 if now.duration_since(last).as_secs() < 10 {
2680 return;
2681 }
2682 }
2683
2684 self.last_auto_capture_at = Some(now);
2685
2686 let thread_id = self.id().clone();
2687 let github_login = self
2688 .project
2689 .read(cx)
2690 .user_store()
2691 .read(cx)
2692 .current_user()
2693 .map(|user| user.github_login.clone());
2694 let client = self.project.read(cx).client();
2695 let serialize_task = self.serialize(cx);
2696
2697 cx.background_executor()
2698 .spawn(async move {
2699 if let Ok(serialized_thread) = serialize_task.await {
2700 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
2701 telemetry::event!(
2702 "Agent Thread Auto-Captured",
2703 thread_id = thread_id.to_string(),
2704 thread_data = thread_data,
2705 auto_capture_reason = "tracked_user",
2706 github_login = github_login
2707 );
2708
2709 client.telemetry().flush_events().await;
2710 }
2711 }
2712 })
2713 .detach();
2714 }
2715
    /// Returns the thread's cumulative [`TokenUsage`].
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
2719
2720 pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
2721 let Some(model) = self.configured_model.as_ref() else {
2722 return TotalTokenUsage::default();
2723 };
2724
2725 let max = model.model.max_token_count();
2726
2727 let index = self
2728 .messages
2729 .iter()
2730 .position(|msg| msg.id == message_id)
2731 .unwrap_or(0);
2732
2733 if index == 0 {
2734 return TotalTokenUsage { total: 0, max };
2735 }
2736
2737 let token_usage = &self
2738 .request_token_usage
2739 .get(index - 1)
2740 .cloned()
2741 .unwrap_or_default();
2742
2743 TotalTokenUsage {
2744 total: token_usage.total_tokens() as usize,
2745 max,
2746 }
2747 }
2748
2749 pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
2750 let model = self.configured_model.as_ref()?;
2751
2752 let max = model.model.max_token_count();
2753
2754 if let Some(exceeded_error) = &self.exceeded_window_error {
2755 if model.model.id() == exceeded_error.model_id {
2756 return Some(TotalTokenUsage {
2757 total: exceeded_error.token_count,
2758 max,
2759 });
2760 }
2761 }
2762
2763 let total = self
2764 .token_usage_at_last_message()
2765 .unwrap_or_default()
2766 .total_tokens() as usize;
2767
2768 Some(TotalTokenUsage { total, max })
2769 }
2770
    /// Returns the usage entry aligned with the last message, falling back to
    /// the most recent entry when the bookkeeping vector is shorter than the
    /// message list.
    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }
2777
    /// Records `token_usage` for the last message, resizing the bookkeeping
    /// vector so it stays aligned with `self.messages`.
    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        // Pad any newly-created slots with the last known usage rather than zeros.
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }
2787
2788 pub fn deny_tool_use(
2789 &mut self,
2790 tool_use_id: LanguageModelToolUseId,
2791 tool_name: Arc<str>,
2792 window: Option<AnyWindowHandle>,
2793 cx: &mut Context<Self>,
2794 ) {
2795 let err = Err(anyhow::anyhow!(
2796 "Permission to run tool action denied by user"
2797 ));
2798
2799 self.tool_use.insert_tool_output(
2800 tool_use_id.clone(),
2801 tool_name,
2802 err,
2803 self.configured_model.as_ref(),
2804 );
2805 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
2806 }
2807}
2808
/// Errors surfaced to the UI (see [`ThreadEvent::ShowError`]).
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// The provider requires payment before more requests are accepted.
    #[error("Payment required")]
    PaymentRequired,
    /// The user's plan has exhausted its model request allowance.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Generic error carrying a display header and detail message.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2821
/// Events emitted by a `Thread` as completions stream, tools run, and
/// messages change. Consumed by UI listeners via gpui's event system.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error to surface to the user.
    ShowError(ThreadError),
    /// Progress was made on the in-flight completion.
    StreamedCompletion,
    /// A text chunk arrived from the model.
    ReceivedTextChunk,
    /// A new completion request was started.
    NewRequest,
    /// Streamed assistant text appended to the given message.
    StreamedAssistantText(MessageId, String),
    /// Streamed assistant thinking text appended to the given message.
    StreamedAssistantThinking(MessageId, String),
    /// The model streamed (part of) a tool-use request.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    /// A tool use was expected but missing.
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// The model produced tool input that failed to parse as valid JSON.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// The completion stopped, with either a stop reason or an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message was appended to the thread.
    MessageAdded(MessageId),
    /// An existing message was edited.
    MessageEdited(MessageId),
    /// A message was removed from the thread.
    MessageDeleted(MessageId),
    /// A summary was generated for the thread.
    SummaryGenerated,
    /// The thread's summary changed.
    SummaryChanged,
    /// Tool uses are ready to be executed.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool finished running (successfully, with an error, or canceled).
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// A checkpoint was inserted or otherwise changed.
    CheckpointChanged,
    /// A tool requires user confirmation before it can run.
    ToolConfirmationNeeded,
    /// Listeners should cancel any in-progress editing.
    CancelEditing,
    /// The in-flight completion was canceled.
    CompletionCanceled,
}
2864
// `Thread` broadcasts `ThreadEvent`s through gpui's event system.
impl EventEmitter<ThreadEvent> for Thread {}
2866
/// Bookkeeping for a completion request that is currently in flight.
struct PendingCompletion {
    // Identifier for this completion (used to match/cancel it).
    id: usize,
    // Queue position/state for this completion. NOTE(review): semantics of
    // `QueueState` are defined elsewhere in this file — confirm before relying.
    queue_state: QueueState,
    // Keeps the driving task alive for the lifetime of this entry.
    _task: Task<()>,
}
2872
2873#[cfg(test)]
2874mod tests {
2875 use super::*;
2876 use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
2877 use agent_settings::{AgentSettings, LanguageModelParameters};
2878 use assistant_tool::ToolRegistry;
2879 use editor::EditorSettings;
2880 use gpui::TestAppContext;
2881 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
2882 use project::{FakeFs, Project};
2883 use prompt_store::PromptBuilder;
2884 use serde_json::json;
2885 use settings::{Settings, SettingsStore};
2886 use std::sync::Arc;
2887 use theme::ThemeSettings;
2888 use util::path;
2889 use workspace::Workspace;
2890
    // Verifies that attached file context is injected into both the stored
    // message and the outgoing completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Please explain this code",
                loaded_context,
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. They are up-to-date and don't need to be re-read.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.loaded_context.text, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2967
    // Verifies that each message only carries context that hasn't already been
    // attached to an earlier message, and that `new_context_for_thread`
    // respects an "up to message" boundary and context removal.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
                "file4.rs": "fn function4() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // First message with context 1
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx)
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), None)
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await;
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx)
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.loaded_context.text.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.loaded_context.text.contains("file1.rs"));
        assert!(message2.loaded_context.text.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.loaded_context.text.contains("file1.rs"));
        assert!(!message3.loaded_context.text.contains("file2.rs"));
        assert!(message3.loaded_context.text.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));

        // With a boundary at message 2, contexts from message 3 onward (and new
        // ones) count as "new" again.
        add_file_to_context(&project, &context_store, "test/file4.rs", cx)
            .await
            .unwrap();
        let new_contexts = context_store.update(cx, |store, cx| {
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 3);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file4.rs
            store.remove_context(&loaded_context.contexts[2].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 2);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));

        let new_contexts = context_store.update(cx, |store, cx| {
            // Remove file3.rs
            store.remove_context(&loaded_context.contexts[1].handle(), cx);
            store.new_context_for_thread(thread.read(cx), Some(message2_id))
        });
        assert_eq!(new_contexts.len(), 1);
        let loaded_context = cx
            .update(|cx| load_context(new_contexts, &project, &None, cx))
            .await
            .loaded_context;

        assert!(!loaded_context.text.contains("file1.rs"));
        assert!(loaded_context.text.contains("file2.rs"));
        assert!(!loaded_context.text.contains("file3.rs"));
        assert!(!loaded_context.text.contains("file4.rs"));
    }
3123
    // Verifies that messages inserted without any context carry empty context
    // text and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What is the best way to learn Rust?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.loaded_context.text, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Are there any good books?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.loaded_context.text, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
3201
    // Verifies that when a context buffer is modified after being attached, the
    // next completion request ends with a "files changed since last read"
    // notification message.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap());
        let loaded_context = cx
            .update(|cx| load_context(vec![context], &project, &None, cx))
            .await;

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Find a position at the end of line 1
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "What does the code do now?",
                ContextLoadResult::default(),
                None,
                Vec::new(),
                cx,
            )
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
3289
    // Verifies how `model_parameters` settings resolve a request temperature:
    // matching on provider+model, model only, or provider only applies the
    // override; a provider mismatch leaves temperature unset.
    #[gpui::test]
    async fn test_temperature_setting(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Both model and provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only model
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: None,
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Only provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some(model.provider_id().0.to_string().into()),
                        model: None,
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, Some(0.66));

        // Same model name, different provider
        cx.update(|cx| {
            AgentSettings::override_global(
                AgentSettings {
                    model_parameters: vec![LanguageModelParameters {
                        provider: Some("anthropic".into()),
                        model: Some(model.id().0.clone()),
                        temperature: Some(0.66),
                    }],
                    ..AgentSettings::get_global(cx).clone()
                },
                cx,
            );
        });

        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
        });
        assert_eq!(request.temperature, None);
    }
3383
    /// Walks the thread summary through its full lifecycle:
    /// `Pending` → `Generating` → `Ready`, checking at each stage whether a
    /// manually supplied summary is accepted or ignored.
    #[gpui::test]
    async fn test_thread_summary(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Initial state should be pending, and `or_default()` falls back to
        // the placeholder summary text
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Manually setting the summary should not be allowed in this state
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Pending));
        });

        // Send a message
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // Should start generating summary when there are >= 2 messages
        thread.read_with(cx, |thread, _| {
            assert_eq!(*thread.summary(), ThreadSummary::Generating);
        });

        // Should not be able to set the summary while generating
        thread.update(cx, |thread, cx| {
            thread.set_summary("This should not work either", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Stream the summary in two chunks; they are expected to be
        // concatenated into a single summary string
        cx.run_until_parked();
        fake_model.stream_last_completion_response("Brief");
        fake_model.stream_last_completion_response(" Introduction");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // Summary should be set
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "Brief Introduction");
        });

        // Now we should be able to set a summary
        thread.update(cx, |thread, cx| {
            thread.set_summary("Brief Intro", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.summary().or_default(), "Brief Intro");
        });

        // Test setting an empty summary (should default to DEFAULT)
        thread.update(cx, |thread, cx| {
            thread.set_summary("", cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3468
3469 #[gpui::test]
3470 async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
3471 init_test_settings(cx);
3472
3473 let project = create_test_project(cx, json!({})).await;
3474
3475 let (_, _thread_store, thread, _context_store, model) =
3476 setup_test_environment(cx, project.clone()).await;
3477
3478 test_summarize_error(&model, &thread, cx);
3479
3480 // Now we should be able to set a summary
3481 thread.update(cx, |thread, cx| {
3482 thread.set_summary("Brief Intro", cx);
3483 });
3484
3485 thread.read_with(cx, |thread, _| {
3486 assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
3487 assert_eq!(thread.summary().or_default(), "Brief Intro");
3488 });
3489 }
3490
    /// After a failed summarization, sending further messages must not retry
    /// automatically; only an explicit `summarize` call retries, and that
    /// retry can succeed.
    #[gpui::test]
    async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(cx, json!({})).await;

        let (_, _thread_store, thread, _context_store, model) =
            setup_test_environment(cx, project.clone()).await;

        // Drive the thread into the summarization-error state.
        test_summarize_error(&model, &thread, cx);

        // Sending another message should not trigger another summarize request
        thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "How are you?",
                ContextLoadResult::default(),
                None,
                vec![],
                cx,
            );
            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        thread.read_with(cx, |thread, _| {
            // State is still Error, not Generating
            assert!(matches!(thread.summary(), ThreadSummary::Error));
        });

        // But the summarize request can be invoked manually
        thread.update(cx, |thread, cx| {
            thread.summarize(cx);
        });

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
        });

        // Stream a successful summary for the manually retried request
        cx.run_until_parked();
        fake_model.stream_last_completion_response("A successful summary");
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
            assert_eq!(thread.summary().or_default(), "A successful summary");
        });
    }
3541
    /// Shared helper: sends a message through the fake model, lets the
    /// assistant reply succeed, then ends the follow-up summary stream
    /// without emitting any text, leaving the thread's summary in the
    /// `Error` state (and `or_default()` falling back to the placeholder).
    fn test_summarize_error(
        model: &Arc<dyn LanguageModel>,
        thread: &Entity<Thread>,
        cx: &mut TestAppContext,
    ) {
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
            thread.send_to_model(
                model.clone(),
                CompletionIntent::ThreadSummarization,
                None,
                cx,
            );
        });

        let fake_model = model.as_fake();
        simulate_successful_response(&fake_model, cx);

        // The completed assistant reply kicks off summary generation
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Generating));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });

        // Simulate summary request ending — no text was streamed, which is
        // presumably what makes the summarization count as failed
        cx.run_until_parked();
        fake_model.end_last_completion_stream();
        cx.run_until_parked();

        // State is set to Error and default message
        thread.read_with(cx, |thread, _| {
            assert!(matches!(thread.summary(), ThreadSummary::Error));
            assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
        });
    }
3576
3577 fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
3578 cx.run_until_parked();
3579 fake_model.stream_last_completion_response("Assistant response");
3580 fake_model.end_last_completion_stream();
3581 cx.run_until_parked();
3582 }
3583
    /// Installs the global state these tests need: a test `SettingsStore`
    /// plus the settings/registries of every subsystem the agent touches.
    /// The settings store must be set first, since the subsequent
    /// `init`/`register` calls read and write global settings.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            prompt_store::init(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            language_model::init_settings(cx);
            ThemeSettings::register(cx);
            EditorSettings::register(cx);
            ToolRegistry::default_global(cx);
        });
    }
3600
3601 // Helper to create a test project with test files
3602 async fn create_test_project(
3603 cx: &mut TestAppContext,
3604 files: serde_json::Value,
3605 ) -> Entity<Project> {
3606 let fs = FakeFs::new(cx.executor());
3607 fs.insert_tree(path!("/test"), files).await;
3608 Project::test(fs, [path!("/test").as_ref()], cx).await
3609 }
3610
    /// Builds the shared test fixture: a test `Workspace`, a `ThreadStore`
    /// loaded for `project`, one fresh `Thread`, an empty `ContextStore`,
    /// and a fake language model registered as both the default and the
    /// thread-summary model.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
        Arc<dyn LanguageModel>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        // Load a thread store with an empty tool working set and a fresh
        // prompt builder.
        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    None,
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await
            .unwrap();

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        let provider = Arc::new(FakeLanguageModelProvider);
        let model = provider.test_model();
        let model: Arc<dyn LanguageModel> = Arc::new(model);

        // Register the fake model for both regular completions and thread
        // summarization, so tests can drive every request deterministically.
        // Note the second call moves `provider`; the first must clone.
        cx.update(|_, cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.set_default_model(
                    Some(ConfiguredModel {
                        provider: provider.clone(),
                        model: model.clone(),
                    }),
                    cx,
                );
                registry.set_thread_summary_model(
                    Some(ConfiguredModel {
                        provider,
                        model: model.clone(),
                    }),
                    cx,
                );
            })
        });

        (workspace, thread_store, thread, context_store, model)
    }
3665
3666 async fn add_file_to_context(
3667 project: &Entity<Project>,
3668 context_store: &Entity<ContextStore>,
3669 path: &str,
3670 cx: &mut TestAppContext,
3671 ) -> Result<Entity<language::Buffer>> {
3672 let buffer_path = project
3673 .read_with(cx, |project, cx| project.find_project_path(path, cx))
3674 .unwrap();
3675
3676 let buffer = project
3677 .update(cx, |project, cx| {
3678 project.open_buffer(buffer_path.clone(), cx)
3679 })
3680 .await
3681 .unwrap();
3682
3683 context_store.update(cx, |context_store, cx| {
3684 context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
3685 });
3686
3687 Ok(buffer)
3688 }
3689}