1use crate::{
2 AgentThread, AgentThreadId, AgentThreadMessageId, AgentThreadUserMessageChunk,
3 agent_profile::AgentProfile,
4 context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
5 thread_store::{SharedProjectContext, ThreadStore},
6};
7use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
8use anyhow::{Result, anyhow};
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use client::{ModelRequestUsage, RequestUsage};
12use collections::{HashMap, HashSet};
13use feature_flags::{self, FeatureFlagAppExt};
14use futures::{FutureExt, StreamExt as _, channel::oneshot, future::Shared};
15use git::repository::DiffType;
16use gpui::{
17 AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
18 WeakEntity,
19};
20use language_model::{
21 ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
22 LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
23 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
24 LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
25 ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
26};
27use postage::stream::Stream as _;
28use project::{
29 Project,
30 git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
31};
32use prompt_store::{ModelContext, PromptBuilder};
33use proto::Plan;
34use serde::{Deserialize, Serialize};
35use settings::Settings;
36use std::{
37 io::Write,
38 ops::Range,
39 sync::Arc,
40 time::{Duration, Instant},
41};
42use thiserror::Error;
43use util::{ResultExt as _, post_inc};
44use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
45
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
    /// Span of the crease; presumably offsets into the message text — TODO confirm units.
    pub range: Range<usize>,
    /// Icon displayed on the crease.
    pub icon_path: SharedString,
    /// Label displayed on the crease.
    pub label: SharedString,
    /// None for a deserialized message, Some otherwise.
    pub context: Option<AgentContextHandle>,
}
55
/// The lifecycle state of a tool use attached to a [`Message`].
pub enum MessageTool {
    /// The tool has been requested with `input` but has not been resolved yet.
    Pending {
        tool: Arc<dyn Tool>,
        input: serde_json::Value,
    },
    /// The tool requires user confirmation before running.
    NeedsConfirmation {
        tool: Arc<dyn Tool>,
        input_json: serde_json::Value,
        /// Presumably receives the user's decision (`true` = allow) — TODO confirm with the UI caller.
        confirm_tx: oneshot::Sender<bool>,
    },
    /// The tool use was confirmed; `card` is the UI card rendered for it.
    Confirmed {
        card: AnyToolCard,
    },
    /// The tool use was declined by the user.
    Declined {
        tool: Arc<dyn Tool>,
        input_json: serde_json::Value,
    },
}
74
/// A message in a [`Thread`].
pub struct Message {
    /// Identifier of the message. [`Thread::message`] binary-searches on it, so messages
    /// are expected to be stored in ascending id order.
    pub id: AgentThreadMessageId,
    /// Author of the message.
    pub role: Role,
    /// Thinking text; rendered between `<think>` tags by [`Thread::text`].
    pub thinking: String,
    /// Visible text of the message.
    pub text: String,
    /// Tool uses attached to this message.
    pub tools: Vec<MessageTool>,
    /// Context that was loaded and attached to this message.
    pub loaded_context: LoadedContext,
    /// Creases used to re-render attached context when editing this message.
    pub creases: Vec<MessageCrease>,
    /// NOTE(review): presumably hides the message from the UI (e.g. synthetic "continue"
    /// messages) — confirm.
    pub is_hidden: bool,
    /// NOTE(review): presumably marks messages that exist only in the UI and are never sent
    /// to the model — confirm.
    pub ui_only: bool,
}
87
/// Point-in-time snapshot of the project, as produced by `Thread::project_snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One entry per visible worktree.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of dirty buffers that are backed by a file (as reported by `File::path`).
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
94
/// Snapshot of a single worktree within a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    /// Absolute path of the worktree; empty if it could not be read.
    pub worktree_path: String,
    /// Git information for the repository containing the worktree, if any.
    pub git_state: Option<GitState>,
}
100
/// Git details captured for a worktree.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    /// URL of the `origin` remote, if any.
    pub remote_url: Option<String>,
    /// SHA of `HEAD`, if it could be resolved.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch, if any.
    pub current_branch: Option<String>,
    /// Diff from `HEAD` to the working tree, if it could be computed.
    pub diff: Option<String>,
}
108
/// A git checkpoint associated with the thread message it was taken for.
#[derive(Clone, Debug)]
pub struct ThreadCheckpoint {
    /// Message the checkpoint belongs to; restoring truncates the thread at this message.
    message_id: AgentThreadMessageId,
    /// Git store checkpoint to restore.
    git_checkpoint: GitStoreCheckpoint,
}
114
/// User rating of a thread or of a single message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
120
/// Status of the most recent checkpoint restore started by `Thread::restore_checkpoint`.
pub enum LastRestoreCheckpoint {
    /// The restore for `message_id` is still running.
    Pending {
        message_id: AgentThreadMessageId,
    },
    /// The restore for `message_id` failed; `error` is the error's display text.
    Error {
        message_id: AgentThreadMessageId,
        error: String,
    },
}
130
131impl LastRestoreCheckpoint {
132 pub fn message_id(&self) -> AgentThreadMessageId {
133 match self {
134 LastRestoreCheckpoint::Pending { message_id } => *message_id,
135 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
136 }
137 }
138}
139
/// Progress of the detailed summary published through `Thread`'s watch channel.
#[derive(Clone, Debug, Default)]
pub enum DetailedSummaryState {
    /// No summary exists (initial state, or the last generation failed).
    #[default]
    NotGenerated,
    /// A summary is being generated for the thread as of `message_id`.
    Generating {
        message_id: AgentThreadMessageId,
    },
    /// `text` summarizes the thread as of `message_id`.
    Generated {
        text: SharedString,
        message_id: AgentThreadMessageId,
    },
}
152
153impl DetailedSummaryState {
154 fn text(&self) -> Option<SharedString> {
155 if let Self::Generated { text, .. } = self {
156 Some(text.clone())
157 } else {
158 None
159 }
160 }
161}
162
/// Aggregate token usage of a thread, relative to the selected model's limit.
#[derive(Default, Debug)]
pub struct TotalTokenUsage {
    /// Tokens used so far.
    pub total: u64,
    /// Maximum tokens the selected model accepts; `0` when unknown (no selected model).
    pub max: u64,
}

impl TotalTokenUsage {
    /// Fraction of `max` at or above which usage is reported as [`TokenUsageRatio::Warning`].
    const DEFAULT_WARNING_THRESHOLD: f32 = 0.8;

    /// Classifies the current usage against `max`.
    ///
    /// In debug builds the warning threshold can be overridden with the
    /// `ZED_THREAD_WARNING_THRESHOLD` environment variable. A missing or unparsable value
    /// falls back to the default instead of panicking.
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(Self::DEFAULT_WARNING_THRESHOLD);
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = Self::DEFAULT_WARNING_THRESHOLD;

        // When the maximum is unknown because there is no selected model,
        // avoid showing the token limit warning.
        if self.max == 0 {
            TokenUsageRatio::Normal
        } else if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    /// Returns a copy of this usage with `tokens` more tokens used.
    ///
    /// The total saturates at `u64::MAX` rather than overflowing.
    pub fn add(&self, tokens: u64) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total.saturating_add(tokens),
            max: self.max,
        }
    }
}

/// Coarse classification of token usage, as returned by [`TotalTokenUsage::ratio`].
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the maximum is unknown.
    #[default]
    Normal,
    /// At or above the warning threshold but below the maximum.
    Warning,
    /// At or above the maximum.
    Exceeded,
}
207
/// Position of a request in a send queue.
///
/// NOTE(review): not referenced in the visible part of this file; the variant meanings
/// below are inferred from their names — confirm against the callers.
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
    /// Presumably: the request is being sent.
    Sending,
    /// Presumably: the request is waiting at `position` in the queue.
    Queued { position: usize },
    /// Presumably: the request has started processing.
    Started,
}
214
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Underlying agent thread; provides the thread id and the detailed summary.
    agent_thread: Arc<dyn AgentThread>,
    /// Short summary (title) of the thread.
    summary: ThreadSummary,
    /// In-flight send; `Some` means the thread is generating (see `is_generating`).
    pending_send: Option<Task<Result<()>>>,
    /// NOTE(review): not written in the visible code; presumably the title-summary task.
    pending_summary: Task<Option<()>>,
    /// Detailed-summary task; replacing it cancels the previous one.
    detailed_summary_task: Task<Option<()>>,
    /// Publishes [`DetailedSummaryState`] changes.
    detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
    /// Observes [`DetailedSummaryState`] changes.
    detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
    /// Completion mode; initialized from the user's preferred completion mode setting.
    completion_mode: agent_settings::CompletionMode,
    /// Messages of the thread, expected in ascending id order (see `message`).
    messages: Vec<Message>,
    /// Git checkpoints recorded per message.
    checkpoints_by_message: HashMap<AgentThreadMessageId, ThreadCheckpoint>,
    project: Entity<Project>,
    /// Log of buffer edits and reads tracked for this thread.
    action_log: Entity<ActionLog>,
    /// Status of the most recent checkpoint restore, if pending or failed.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint awaiting comparison once generation stops (see `finalize_pending_checkpoint`).
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured when the thread was loaded.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// NOTE(review): presumably per-request token usage — confirm.
    request_token_usage: Vec<TokenUsage>,
    /// NOTE(review): presumably token usage summed across requests — confirm.
    cumulative_token_usage: TokenUsage,
    /// Set when the last message exceeded the model's context window.
    exceeded_window_error: Option<ExceededWindowError>,
    tool_use_limit_reached: bool,
    // todo!(keep track of retries from the underlying agent)
    /// Rating of the thread as a whole.
    feedback: Option<ThreadFeedback>,
    /// Ratings of individual messages.
    message_feedback: HashMap<AgentThreadMessageId, ThreadFeedback>,
    /// NOTE(review): presumably used to throttle telemetry auto-capture — confirm.
    last_auto_capture_at: Option<Instant>,
    /// When the last streamed chunk arrived; used by `is_generation_stale`.
    last_received_chunk_at: Option<Instant>,
}
242
/// State of a thread's short summary (its title).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
    /// No summary has been requested yet.
    Pending,
    /// A summary is being generated.
    Generating,
    /// The summary is available.
    Ready(SharedString),
    /// Generating the summary failed.
    Error,
}
250
251impl ThreadSummary {
252 pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
253
254 pub fn or_default(&self) -> SharedString {
255 self.unwrap_or(Self::DEFAULT)
256 }
257
258 pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
259 self.ready().unwrap_or_else(|| message.into())
260 }
261
262 pub fn ready(&self) -> Option<SharedString> {
263 match self {
264 ThreadSummary::Ready(summary) => Some(summary.clone()),
265 ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
266 }
267 }
268}
269
/// Details recorded when a message pushed the thread past the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: u64,
}
277
278impl Thread {
    /// Creates a [`Thread`] that wraps `agent_thread` within `project`.
    ///
    /// # Panics
    ///
    /// Always panics for now: `messages` is still `todo!("read from agent")`. Struct-literal
    /// fields are evaluated in the order written, so everything after `messages` (the action
    /// log and the initial project snapshot) is never reached. Keep that ordering in mind
    /// when porting.
    pub fn load(
        agent_thread: Arc<dyn AgentThread>,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Self {
        let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
        Self {
            agent_thread,
            summary: ThreadSummary::Pending,
            pending_send: None,
            pending_summary: Task::ready(None),
            detailed_summary_task: Task::ready(None),
            detailed_summary_tx,
            detailed_summary_rx,
            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
            messages: todo!("read from agent"),
            checkpoints_by_message: HashMap::default(),
            project: project.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Captured once at load time; shared so multiple consumers can await it.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            tool_use_limit_reached: false,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
            last_received_chunk_at: None,
        }
    }
316
317 pub fn id(&self) -> AgentThreadId {
318 self.agent_thread.id()
319 }
320
    /// Returns the agent profile used by this thread.
    ///
    /// # Panics
    ///
    /// Not yet ported to the agent-backed thread; always panics via `todo!()`.
    pub fn profile(&self) -> &AgentProfile {
        todo!()
    }
324
    /// Switches the thread to the profile identified by `id`.
    ///
    /// # Panics
    ///
    /// Not yet ported; always panics via `todo!()`. The commented-out code is the previous
    /// implementation, kept as a porting reference.
    pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
        todo!()
        // if &id != self.profile.id() {
        //     self.profile = AgentProfile::new(id, self.tools.clone());
        //     cx.emit(ThreadEvent::ProfileChanged);
        // }
    }
332
333 pub fn is_empty(&self) -> bool {
334 self.messages.is_empty()
335 }
336
    /// Starts a new prompt id for subsequent requests.
    ///
    /// # Panics
    ///
    /// Not yet ported; always panics via `todo!()`.
    pub fn advance_prompt_id(&mut self) {
        todo!()
        // self.last_prompt_id = PromptId::new();
    }
341
    /// Returns the shared project context used when building prompts.
    ///
    /// # Panics
    ///
    /// Not yet ported; always panics via `todo!()`.
    pub fn project_context(&self) -> SharedProjectContext {
        todo!()
        // self.project_context.clone()
    }
346
347 pub fn summary(&self) -> &ThreadSummary {
348 &self.summary
349 }
350
    /// Replaces the thread's summary with `new_summary`.
    ///
    /// # Panics
    ///
    /// Not yet ported; always panics via `todo!()`. The commented-out code is the previous
    /// implementation, kept as a porting reference. It ignored the call while a summary was
    /// pending or generating, fell back to [`ThreadSummary::DEFAULT`] for an empty summary,
    /// and emitted `SummaryChanged` only when the text changed.
    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
        todo!()
        // let current_summary = match &self.summary {
        //     ThreadSummary::Pending | ThreadSummary::Generating => return,
        //     ThreadSummary::Ready(summary) => summary,
        //     ThreadSummary::Error => &ThreadSummary::DEFAULT,
        // };

        // let mut new_summary = new_summary.into();

        // if new_summary.is_empty() {
        //     new_summary = ThreadSummary::DEFAULT;
        // }

        // if current_summary != &new_summary {
        //     self.summary = ThreadSummary::Ready(new_summary);
        //     cx.emit(ThreadEvent::SummaryChanged);
        // }
    }
370
371 pub fn completion_mode(&self) -> CompletionMode {
372 self.completion_mode
373 }
374
375 pub fn set_completion_mode(&mut self, mode: CompletionMode) {
376 self.completion_mode = mode;
377 }
378
379 pub fn message(&self, id: AgentThreadMessageId) -> Option<&Message> {
380 let index = self
381 .messages
382 .binary_search_by(|message| message.id.cmp(&id))
383 .ok()?;
384
385 self.messages.get(index)
386 }
387
388 pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
389 self.messages.iter()
390 }
391
392 pub fn is_generating(&self) -> bool {
393 self.pending_send.is_some()
394 }
395
396 /// Indicates whether streaming of language model events is stale.
397 /// When `is_generating()` is false, this method returns `None`.
398 pub fn is_generation_stale(&self) -> Option<bool> {
399 const STALE_THRESHOLD: u128 = 250;
400
401 self.last_received_chunk_at
402 .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
403 }
404
405 fn received_chunk(&mut self) {
406 self.last_received_chunk_at = Some(Instant::now());
407 }
408
409 pub fn checkpoint_for_message(&self, id: AgentThreadMessageId) -> Option<ThreadCheckpoint> {
410 self.checkpoints_by_message.get(&id).cloned()
411 }
412
413 pub fn restore_checkpoint(
414 &mut self,
415 checkpoint: ThreadCheckpoint,
416 cx: &mut Context<Self>,
417 ) -> Task<Result<()>> {
418 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
419 message_id: checkpoint.message_id,
420 });
421 cx.emit(ThreadEvent::CheckpointChanged);
422 cx.notify();
423
424 let git_store = self.project().read(cx).git_store().clone();
425 let restore = git_store.update(cx, |git_store, cx| {
426 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
427 });
428
429 cx.spawn(async move |this, cx| {
430 let result = restore.await;
431 this.update(cx, |this, cx| {
432 if let Err(err) = result.as_ref() {
433 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
434 message_id: checkpoint.message_id,
435 error: err.to_string(),
436 });
437 } else {
438 this.truncate(checkpoint.message_id, cx);
439 this.last_restore_checkpoint = None;
440 }
441 this.pending_checkpoint = None;
442 cx.emit(ThreadEvent::CheckpointChanged);
443 cx.notify();
444 })?;
445 result
446 })
447 }
448
449 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
450 let pending_checkpoint = if self.is_generating() {
451 return;
452 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
453 checkpoint
454 } else {
455 return;
456 };
457
458 self.finalize_checkpoint(pending_checkpoint, cx);
459 }
460
461 fn finalize_checkpoint(
462 &mut self,
463 pending_checkpoint: ThreadCheckpoint,
464 cx: &mut Context<Self>,
465 ) {
466 let git_store = self.project.read(cx).git_store().clone();
467 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
468 cx.spawn(async move |this, cx| match final_checkpoint.await {
469 Ok(final_checkpoint) => {
470 let equal = git_store
471 .update(cx, |store, cx| {
472 store.compare_checkpoints(
473 pending_checkpoint.git_checkpoint.clone(),
474 final_checkpoint.clone(),
475 cx,
476 )
477 })?
478 .await
479 .unwrap_or(false);
480
481 if !equal {
482 this.update(cx, |this, cx| {
483 this.insert_checkpoint(pending_checkpoint, cx)
484 })?;
485 }
486
487 Ok(())
488 }
489 Err(_) => this.update(cx, |this, cx| {
490 this.insert_checkpoint(pending_checkpoint, cx)
491 }),
492 })
493 .detach();
494 }
495
496 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
497 self.checkpoints_by_message
498 .insert(checkpoint.message_id, checkpoint);
499 cx.emit(ThreadEvent::CheckpointChanged);
500 cx.notify();
501 }
502
503 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
504 self.last_restore_checkpoint.as_ref()
505 }
506
    /// Removes the message `message_id` and every message after it, along with their checkpoints.
    ///
    /// # Panics
    ///
    /// Currently always panics: the agent-side truncation is still `todo!()`. The local
    /// cleanup below is therefore unreachable (the compiler reports it as such) until the
    /// agent call is implemented.
    pub fn truncate(&mut self, message_id: AgentThreadMessageId, cx: &mut Context<Self>) {
        todo!("call truncate on the agent");
        // Unknown ids are ignored.
        let Some(message_ix) = self
            .messages
            .iter()
            .rposition(|message| message.id == message_id)
        else {
            return;
        };
        for deleted_message in self.messages.drain(message_ix..) {
            self.checkpoints_by_message.remove(&deleted_message.id);
        }
        cx.notify();
    }
521
    /// Returns whether the message at index `ix` ends an agent turn.
    ///
    /// # Panics
    ///
    /// Not yet ported; always panics via `todo!()`. The commented-out code is the previous
    /// implementation. It treated the last message as a turn end when not generating, and
    /// otherwise an assistant message followed by a visible user message.
    pub fn is_turn_end(&self, ix: usize) -> bool {
        todo!()
        // if self.messages.is_empty() {
        //     return false;
        // }

        // if !self.is_generating() && ix == self.messages.len() - 1 {
        //     return true;
        // }

        // let Some(message) = self.messages.get(ix) else {
        //     return false;
        // };

        // if message.role != Role::Assistant {
        //     return false;
        // }

        // self.messages
        //     .get(ix + 1)
        //     .and_then(|message| {
        //         self.message(message.id)
        //             .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
        //     })
        //     .unwrap_or(false)
    }
548
549 pub fn tool_use_limit_reached(&self) -> bool {
550 self.tool_use_limit_reached
551 }
552
    /// Returns whether any pending tool uses may perform edits
    ///
    /// # Panics
    ///
    /// Not yet ported; always panics via `todo!()`.
    pub fn has_pending_edit_tool_uses(&self) -> bool {
        todo!()
    }
557
558 // pub fn insert_user_message(
559 // &mut self,
560 // text: impl Into<String>,
561 // loaded_context: ContextLoadResult,
562 // git_checkpoint: Option<GitStoreCheckpoint>,
563 // creases: Vec<MessageCrease>,
564 // cx: &mut Context<Self>,
565 // ) -> AgentThreadMessageId {
566 // todo!("move this logic into send")
567 // if !loaded_context.referenced_buffers.is_empty() {
568 // self.action_log.update(cx, |log, cx| {
569 // for buffer in loaded_context.referenced_buffers {
570 // log.buffer_read(buffer, cx);
571 // }
572 // });
573 // }
574
575 // let message_id = self.insert_message(
576 // Role::User,
577 // vec![MessageSegment::Text(text.into())],
578 // loaded_context.loaded_context,
579 // creases,
580 // false,
581 // cx,
582 // );
583
584 // if let Some(git_checkpoint) = git_checkpoint {
585 // self.pending_checkpoint = Some(ThreadCheckpoint {
586 // message_id,
587 // git_checkpoint,
588 // });
589 // }
590
591 // self.auto_capture_telemetry(cx);
592
593 // message_id
594 // }
595
596 pub fn send(&mut self, message: Vec<AgentThreadUserMessageChunk>, cx: &mut Context<Self>) {}
597
    /// Resumes generation on the underlying agent thread.
    ///
    /// # Panics
    ///
    /// Not yet implemented; always panics via `todo!()`.
    pub fn resume(&mut self, cx: &mut Context<Self>) {
        todo!()
    }
601
    /// Replaces the contents of the user message `message_id` with `message`.
    ///
    /// # Panics
    ///
    /// Not yet implemented; always panics via `todo!()`.
    pub fn edit(
        &mut self,
        message_id: AgentThreadMessageId,
        message: Vec<AgentThreadUserMessageChunk>,
        cx: &mut Context<Self>,
    ) {
        todo!()
    }
610
    /// Cancels the in-flight generation.
    ///
    /// # Panics
    ///
    /// Not yet implemented; always panics via `todo!()`.
    pub fn cancel(&mut self, cx: &mut Context<Self>) {
        todo!()
    }
614
615 // pub fn insert_invisible_continue_message(
616 // &mut self,
617 // cx: &mut Context<Self>,
618 // ) -> AgentThreadMessageId {
619 // let id = self.insert_message(
620 // Role::User,
621 // vec![MessageSegment::Text("Continue where you left off".into())],
622 // LoadedContext::default(),
623 // vec![],
624 // true,
625 // cx,
626 // );
627 // self.pending_checkpoint = None;
628
629 // id
630 // }
631
632 // pub fn insert_assistant_message(
633 // &mut self,
634 // segments: Vec<MessageSegment>,
635 // cx: &mut Context<Self>,
636 // ) -> AgentThreadMessageId {
637 // self.insert_message(
638 // Role::Assistant,
639 // segments,
640 // LoadedContext::default(),
641 // Vec::new(),
642 // false,
643 // cx,
644 // )
645 // }
646
647 // pub fn insert_message(
648 // &mut self,
649 // role: Role,
650 // segments: Vec<MessageSegment>,
651 // loaded_context: LoadedContext,
652 // creases: Vec<MessageCrease>,
653 // is_hidden: bool,
654 // cx: &mut Context<Self>,
655 // ) -> AgentThreadMessageId {
656 // let id = self.next_message_id.post_inc();
657 // self.messages.push(Message {
658 // id,
659 // role,
660 // segments,
661 // loaded_context,
662 // creases,
663 // is_hidden,
664 // ui_only: false,
665 // });
666 // self.touch_updated_at();
667 // cx.emit(ThreadEvent::MessageAdded(id));
668 // id
669 // }
670
671 // pub fn edit_message(
672 // &mut self,
673 // id: AgentThreadMessageId,
674 // new_role: Role,
675 // new_segments: Vec<MessageSegment>,
676 // creases: Vec<MessageCrease>,
677 // loaded_context: Option<LoadedContext>,
678 // checkpoint: Option<GitStoreCheckpoint>,
679 // cx: &mut Context<Self>,
680 // ) -> bool {
681 // let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
682 // return false;
683 // };
684 // message.role = new_role;
685 // message.segments = new_segments;
686 // message.creases = creases;
687 // if let Some(context) = loaded_context {
688 // message.loaded_context = context;
689 // }
690 // if let Some(git_checkpoint) = checkpoint {
691 // self.checkpoints_by_message.insert(
692 // id,
693 // ThreadCheckpoint {
694 // message_id: id,
695 // git_checkpoint,
696 // },
697 // );
698 // }
699 // self.touch_updated_at();
700 // cx.emit(ThreadEvent::MessageEdited(id));
701 // true
702 // }
703
704 /// Returns the representation of this [`Thread`] in a textual form.
705 ///
706 /// This is the representation we use when attaching a thread as context to another thread.
707 pub fn text(&self) -> String {
708 let mut text = String::new();
709
710 for message in &self.messages {
711 text.push_str(match message.role {
712 language_model::Role::User => "User:",
713 language_model::Role::Assistant => "Agent:",
714 language_model::Role::System => "System:",
715 });
716 text.push('\n');
717
718 text.push_str("<think>");
719 text.push_str(&message.thinking);
720 text.push_str("</think>");
721 text.push_str(&message.text);
722
723 // todo!('what about tools?');
724
725 text.push('\n');
726 }
727
728 text
729 }
730
    /// Returns whether any tool produced results since the most recent user message.
    ///
    /// # Panics
    ///
    /// Not yet ported; always panics via `todo!()`. The commented-out code is the previous
    /// implementation, which scanned messages backwards until a user message.
    pub fn used_tools_since_last_user_message(&self) -> bool {
        todo!()
        // for message in self.messages.iter().rev() {
        //     if self.tool_use.message_has_tool_results(message.id) {
        //         return true;
        //     } else if message.role == Role::User {
        //         return false;
        //     }
        // }

        // false
    }
743
744 pub fn start_generating_detailed_summary_if_needed(
745 &mut self,
746 thread_store: WeakEntity<ThreadStore>,
747 cx: &mut Context<Self>,
748 ) {
749 let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
750 return;
751 };
752
753 match &*self.detailed_summary_rx.borrow() {
754 DetailedSummaryState::Generating { message_id, .. }
755 | DetailedSummaryState::Generated { message_id, .. }
756 if *message_id == last_message_id =>
757 {
758 // Already up-to-date
759 return;
760 }
761 _ => {}
762 }
763
764 let summary = self.agent_thread.summary();
765
766 *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
767 message_id: last_message_id,
768 };
769
770 // Replace the detailed summarization task if there is one, cancelling it. It would probably
771 // be better to allow the old task to complete, but this would require logic for choosing
772 // which result to prefer (the old task could complete after the new one, resulting in a
773 // stale summary).
774 self.detailed_summary_task = cx.spawn(async move |thread, cx| {
775 let Some(summary) = summary.await.log_err() else {
776 thread
777 .update(cx, |thread, _cx| {
778 *thread.detailed_summary_tx.borrow_mut() =
779 DetailedSummaryState::NotGenerated;
780 })
781 .ok()?;
782 return None;
783 };
784
785 thread
786 .update(cx, |thread, _cx| {
787 *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
788 text: summary.into(),
789 message_id: last_message_id,
790 };
791 })
792 .ok()?;
793
794 Some(())
795 });
796 }
797
798 pub async fn wait_for_detailed_summary_or_text(
799 this: &Entity<Self>,
800 cx: &mut AsyncApp,
801 ) -> Option<SharedString> {
802 let mut detailed_summary_rx = this
803 .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
804 .ok()?;
805 loop {
806 match detailed_summary_rx.recv().await? {
807 DetailedSummaryState::Generating { .. } => {}
808 DetailedSummaryState::NotGenerated => {
809 return this.read_with(cx, |this, _cx| this.text().into()).ok();
810 }
811 DetailedSummaryState::Generated { text, .. } => return Some(text),
812 }
813 }
814 }
815
816 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
817 self.detailed_summary_rx
818 .borrow()
819 .text()
820 .unwrap_or_else(|| self.text().into())
821 }
822
823 pub fn is_generating_detailed_summary(&self) -> bool {
824 matches!(
825 &*self.detailed_summary_rx.borrow(),
826 DetailedSummaryState::Generating { .. }
827 )
828 }
829
830 pub fn feedback(&self) -> Option<ThreadFeedback> {
831 self.feedback
832 }
833
834 pub fn message_feedback(&self, message_id: AgentThreadMessageId) -> Option<ThreadFeedback> {
835 self.message_feedback.get(&message_id).copied()
836 }
837
    /// Records `feedback` for message `message_id` and reports it via telemetry.
    ///
    /// # Panics
    ///
    /// Not yet ported; always panics via `todo!()`. The commented-out code is the previous
    /// implementation, which depends on `serialize`, `profile` and `tool_uses` that this
    /// type no longer has.
    pub fn report_message_feedback(
        &mut self,
        message_id: AgentThreadMessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        todo!()
        // if self.message_feedback.get(&message_id) == Some(&feedback) {
        //     return Task::ready(Ok(()));
        // }

        // let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        // let serialized_thread = self.serialize(cx);
        // let thread_id = self.id().clone();
        // let client = self.project.read(cx).client();

        // let enabled_tool_names: Vec<String> = self
        //     .profile
        //     .enabled_tools(cx)
        //     .iter()
        //     .map(|tool| tool.name())
        //     .collect();

        // self.message_feedback.insert(message_id, feedback);

        // cx.notify();

        // let message_content = self
        //     .message(message_id)
        //     .map(|msg| msg.to_string())
        //     .unwrap_or_default();

        // cx.background_spawn(async move {
        //     let final_project_snapshot = final_project_snapshot.await;
        //     let serialized_thread = serialized_thread.await?;
        //     let thread_data =
        //         serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

        //     let rating = match feedback {
        //         ThreadFeedback::Positive => "positive",
        //         ThreadFeedback::Negative => "negative",
        //     };
        //     telemetry::event!(
        //         "Assistant Thread Rated",
        //         rating,
        //         thread_id,
        //         enabled_tool_names,
        //         message_id = message_id,
        //         message_content,
        //         thread_data,
        //         final_project_snapshot
        //     );
        //     client.telemetry().flush_events().await;

        //     Ok(())
        // })
    }
895
    /// Records `feedback` for the thread and reports it via telemetry.
    ///
    /// # Panics
    ///
    /// Not yet ported; always panics via `todo!()`. The commented-out code is the previous
    /// implementation. It attributed the feedback to the last assistant message when there
    /// was one, and otherwise to the thread as a whole.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        todo!()
        // let last_assistant_message_id = self
        //     .messages
        //     .iter()
        //     .rev()
        //     .find(|msg| msg.role == Role::Assistant)
        //     .map(|msg| msg.id);

        // if let Some(message_id) = last_assistant_message_id {
        //     self.report_message_feedback(message_id, feedback, cx)
        // } else {
        //     let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        //     let serialized_thread = self.serialize(cx);
        //     let thread_id = self.id().clone();
        //     let client = self.project.read(cx).client();
        //     self.feedback = Some(feedback);
        //     cx.notify();

        //     cx.background_spawn(async move {
        //         let final_project_snapshot = final_project_snapshot.await;
        //         let serialized_thread = serialized_thread.await?;
        //         let thread_data = serde_json::to_value(serialized_thread)
        //             .unwrap_or_else(|_| serde_json::Value::Null);

        //         let rating = match feedback {
        //             ThreadFeedback::Positive => "positive",
        //             ThreadFeedback::Negative => "negative",
        //         };
        //         telemetry::event!(
        //             "Assistant Thread Rated",
        //             rating,
        //             thread_id,
        //             thread_data,
        //             final_project_snapshot
        //         );
        //         client.telemetry().flush_events().await;

        //         Ok(())
        //     })
        // }
    }
942
943 /// Create a snapshot of the current project state including git information and unsaved buffers.
944 fn project_snapshot(
945 project: Entity<Project>,
946 cx: &mut Context<Self>,
947 ) -> Task<Arc<ProjectSnapshot>> {
948 let git_store = project.read(cx).git_store().clone();
949 let worktree_snapshots: Vec<_> = project
950 .read(cx)
951 .visible_worktrees(cx)
952 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
953 .collect();
954
955 cx.spawn(async move |_, cx| {
956 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
957
958 let mut unsaved_buffers = Vec::new();
959 cx.update(|app_cx| {
960 let buffer_store = project.read(app_cx).buffer_store();
961 for buffer_handle in buffer_store.read(app_cx).buffers() {
962 let buffer = buffer_handle.read(app_cx);
963 if buffer.is_dirty() {
964 if let Some(file) = buffer.file() {
965 let path = file.path().to_string_lossy().to_string();
966 unsaved_buffers.push(path);
967 }
968 }
969 }
970 })
971 .ok();
972
973 Arc::new(ProjectSnapshot {
974 worktree_snapshots,
975 unsaved_buffer_paths: unsaved_buffers,
976 timestamp: Utc::now(),
977 })
978 })
979 }
980
981 fn worktree_snapshot(
982 worktree: Entity<project::Worktree>,
983 git_store: Entity<GitStore>,
984 cx: &App,
985 ) -> Task<WorktreeSnapshot> {
986 cx.spawn(async move |cx| {
987 // Get worktree path and snapshot
988 let worktree_info = cx.update(|app_cx| {
989 let worktree = worktree.read(app_cx);
990 let path = worktree.abs_path().to_string_lossy().to_string();
991 let snapshot = worktree.snapshot();
992 (path, snapshot)
993 });
994
995 let Ok((worktree_path, _snapshot)) = worktree_info else {
996 return WorktreeSnapshot {
997 worktree_path: String::new(),
998 git_state: None,
999 };
1000 };
1001
1002 let git_state = git_store
1003 .update(cx, |git_store, cx| {
1004 git_store
1005 .repositories()
1006 .values()
1007 .find(|repo| {
1008 repo.read(cx)
1009 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1010 .is_some()
1011 })
1012 .cloned()
1013 })
1014 .ok()
1015 .flatten()
1016 .map(|repo| {
1017 repo.update(cx, |repo, _| {
1018 let current_branch =
1019 repo.branch.as_ref().map(|branch| branch.name().to_owned());
1020 repo.send_job(None, |state, _| async move {
1021 let RepositoryState::Local { backend, .. } = state else {
1022 return GitState {
1023 remote_url: None,
1024 head_sha: None,
1025 current_branch,
1026 diff: None,
1027 };
1028 };
1029
1030 let remote_url = backend.remote_url("origin");
1031 let head_sha = backend.head_sha().await;
1032 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1033
1034 GitState {
1035 remote_url,
1036 head_sha,
1037 current_branch,
1038 diff,
1039 }
1040 })
1041 })
1042 });
1043
1044 let git_state = match git_state {
1045 Some(git_state) => match git_state.ok() {
1046 Some(git_state) => git_state.await.ok(),
1047 None => None,
1048 },
1049 None => None,
1050 };
1051
1052 WorktreeSnapshot {
1053 worktree_path,
1054 git_state,
1055 }
1056 })
1057 }
1058
    /// Renders the thread as a Markdown document.
    ///
    /// # Panics
    ///
    /// Not yet ported; always panics via `todo!()`. The commented-out code is the previous
    /// implementation, written against the old segment/tool-use model.
    pub fn to_markdown(&self, cx: &App) -> Result<String> {
        todo!()
        // let mut markdown = Vec::new();

        // let summary = self.summary().or_default();
        // writeln!(markdown, "# {summary}\n")?;

        // for message in self.messages() {
        //     writeln!(
        //         markdown,
        //         "## {role}\n",
        //         role = match message.role {
        //             Role::User => "User",
        //             Role::Assistant => "Agent",
        //             Role::System => "System",
        //         }
        //     )?;

        //     if !message.loaded_context.text.is_empty() {
        //         writeln!(markdown, "{}", message.loaded_context.text)?;
        //     }

        //     if !message.loaded_context.images.is_empty() {
        //         writeln!(
        //             markdown,
        //             "\n{} images attached as context.\n",
        //             message.loaded_context.images.len()
        //         )?;
        //     }

        //     for segment in &message.segments {
        //         match segment {
        //             MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
        //             MessageSegment::Thinking { text, .. } => {
        //                 writeln!(markdown, "<think>\n{}\n</think>\n", text)?
        //             }
        //             MessageSegment::RedactedThinking(_) => {}
        //         }
        //     }

        //     for tool_use in self.tool_uses_for_message(message.id, cx) {
        //         writeln!(
        //             markdown,
        //             "**Use Tool: {} ({})**",
        //             tool_use.name, tool_use.id
        //         )?;
        //         writeln!(markdown, "```json")?;
        //         writeln!(
        //             markdown,
        //             "{}",
        //             serde_json::to_string_pretty(&tool_use.input)?
        //         )?;
        //         writeln!(markdown, "```")?;
        //     }

        //     for tool_result in self.tool_results_for_message(message.id) {
        //         write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
        //         if tool_result.is_error {
        //             write!(markdown, " (Error)")?;
        //         }

        //         writeln!(markdown, "**\n")?;
        //         match &tool_result.content {
        //             LanguageModelToolResultContent::Text(text) => {
        //                 writeln!(markdown, "{text}")?;
        //             }
        //             LanguageModelToolResultContent::Image(image) => {
        //                 // NOTE(review): this format string looks truncated — it has no `{}`
        //                 // placeholder for `image.source`; restore it when porting.
        //                 writeln!(markdown, "", image.source)?;
        //             }
        //         }

        //         if let Some(output) = tool_result.output.as_ref() {
        //             writeln!(
        //                 markdown,
        //                 "\n\nDebug Output:\n\n```json\n{}\n```\n",
        //                 serde_json::to_string_pretty(output)?
        //             )?;
        //         }
        //     }
        // }

        // Ok(String::from_utf8_lossy(&markdown).to_string())
    }
1142
1143 pub fn keep_edits_in_range(
1144 &mut self,
1145 buffer: Entity<language::Buffer>,
1146 buffer_range: Range<language::Anchor>,
1147 cx: &mut Context<Self>,
1148 ) {
1149 self.action_log.update(cx, |action_log, cx| {
1150 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1151 });
1152 }
1153
1154 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1155 self.action_log
1156 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1157 }
1158
1159 pub fn reject_edits_in_ranges(
1160 &mut self,
1161 buffer: Entity<language::Buffer>,
1162 buffer_ranges: Vec<Range<language::Anchor>>,
1163 cx: &mut Context<Self>,
1164 ) -> Task<Result<()>> {
1165 self.action_log.update(cx, |action_log, cx| {
1166 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1167 })
1168 }
1169
1170 pub fn action_log(&self) -> &Entity<ActionLog> {
1171 &self.action_log
1172 }
1173
1174 pub fn project(&self) -> &Entity<Project> {
1175 &self.project
1176 }
1177
    /// Intended to auto-capture this thread into telemetry.
    ///
    /// NOTE(review): not implemented yet. Calling this panics via `todo!()`. The commented-out
    /// body below is kept as the reference implementation to port. As written, it:
    /// - returns early unless the `ThreadAutoCaptureFeatureFlag` feature flag is enabled;
    /// - throttles captures to at most one every 10 seconds using `last_auto_capture_at`;
    /// - awaits `self.serialize(cx)` on the background executor;
    /// - emits an "Agent Thread Auto-Captured" telemetry event carrying the serialized thread
    ///   and the current user's GitHub login (if any);
    /// - flushes pending telemetry events.
    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
        todo!()
        // if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
        //     return;
        // }

        // let now = Instant::now();
        // if let Some(last) = self.last_auto_capture_at {
        //     if now.duration_since(last).as_secs() < 10 {
        //         return;
        //     }
        // }

        // self.last_auto_capture_at = Some(now);

        // let thread_id = self.id().clone();
        // let github_login = self
        //     .project
        //     .read(cx)
        //     .user_store()
        //     .read(cx)
        //     .current_user()
        //     .map(|user| user.github_login.clone());
        // let client = self.project.read(cx).client();
        // let serialize_task = self.serialize(cx);

        // cx.background_executor()
        //     .spawn(async move {
        //         if let Ok(serialized_thread) = serialize_task.await {
        //             if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
        //                 telemetry::event!(
        //                     "Agent Thread Auto-Captured",
        //                     thread_id = thread_id.to_string(),
        //                     thread_data = thread_data,
        //                     auto_capture_reason = "tracked_user",
        //                     github_login = github_login
        //                 );

        //                 client.telemetry().flush_events().await;
        //             }
        //         }
        //     })
        //     .detach();
    }
1222
1223 pub fn cumulative_token_usage(&self) -> TokenUsage {
1224 self.cumulative_token_usage
1225 }
1226
    /// Intended to report the token usage accumulated before `message_id`, paired with the
    /// configured model's maximum token count.
    ///
    /// NOTE(review): not implemented yet. Calling this panics via `todo!()`. The commented-out
    /// reference body:
    /// - returns `TotalTokenUsage::default()` when no model is configured;
    /// - returns a total of 0 when `message_id` is the first message, or is not found at all
    ///   (`position` falls back to index 0);
    /// - otherwise uses `total_tokens()` of the `request_token_usage` entry just before the
    ///   message's index, defaulting to zero usage if that entry is missing.
    pub fn token_usage_up_to_message(&self, message_id: AgentThreadMessageId) -> TotalTokenUsage {
        todo!()
        // let Some(model) = self.configured_model.as_ref() else {
        //     return TotalTokenUsage::default();
        // };

        // let max = model.model.max_token_count();

        // let index = self
        //     .messages
        //     .iter()
        //     .position(|msg| msg.id == message_id)
        //     .unwrap_or(0);

        // if index == 0 {
        //     return TotalTokenUsage { total: 0, max };
        // }

        // let token_usage = &self
        //     .request_token_usage
        //     .get(index - 1)
        //     .cloned()
        //     .unwrap_or_default();

        // TotalTokenUsage {
        //     total: token_usage.total_tokens(),
        //     max,
        // }
    }
1256
    /// Intended to report the thread's current total token usage against the configured
    /// model's maximum token count.
    ///
    /// NOTE(review): not implemented yet. Calling this panics via `todo!()`. The commented-out
    /// reference body:
    /// - returns `None` when no model is configured;
    /// - if an `exceeded_window_error` was recorded for the currently configured model, reports
    ///   that error's `token_count`;
    /// - otherwise reports `total_tokens()` of `token_usage_at_last_message()`, or zero when
    ///   nothing has been recorded.
    pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
        todo!()
        // let model = self.configured_model.as_ref()?;

        // let max = model.model.max_token_count();

        // if let Some(exceeded_error) = &self.exceeded_window_error {
        //     if model.model.id() == exceeded_error.model_id {
        //         return Some(TotalTokenUsage {
        //             total: exceeded_error.token_count,
        //             max,
        //         });
        //     }
        // }

        // let total = self
        //     .token_usage_at_last_message()
        //     .unwrap_or_default()
        //     .total_tokens();

        // Some(TotalTokenUsage { total, max })
    }
1279
1280 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
1281 self.request_token_usage
1282 .get(self.messages.len().saturating_sub(1))
1283 .or_else(|| self.request_token_usage.last())
1284 .cloned()
1285 }
1286
1287 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
1288 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
1289 self.request_token_usage
1290 .resize(self.messages.len(), placeholder);
1291
1292 if let Some(last) = self.request_token_usage.last_mut() {
1293 *last = token_usage;
1294 }
1295 }
1296
1297 fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
1298 self.project.update(cx, |project, cx| {
1299 project.user_store().update(cx, |user_store, cx| {
1300 user_store.update_model_request_usage(
1301 ModelRequestUsage(RequestUsage {
1302 amount: amount as i32,
1303 limit,
1304 }),
1305 cx,
1306 )
1307 })
1308 });
1309 }
1310}
1311
/// Errors a thread can report, for example through [`ThreadEvent::ShowError`].
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Displays as "Payment required".
    #[error("Payment required")]
    PaymentRequired,
    /// Displays as "Model request limit reached" and carries a [`Plan`].
    /// NOTE(review): presumably the user's current plan; confirm at the construction site.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// Free-form error with a header and a message; displays as `Message {header}: {message}`.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
1324
/// Events emitted by [`Thread`], which implements `EventEmitter<ThreadEvent>` right after
/// this enum.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries a [`ThreadError`] to surface.
    ShowError(ThreadError),
    StreamedCompletion,
    ReceivedTextChunk,
    NewRequest,
    /// Streamed assistant text for the identified message.
    StreamedAssistantText(AgentThreadMessageId, String),
    /// Streamed assistant "thinking" text for the identified message.
    StreamedAssistantThinking(AgentThreadMessageId, String),
    /// A tool use streamed from the model: its id, display text and JSON input.
    StreamedToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        input: serde_json::Value,
    },
    MissingToolUse {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
    },
    /// `invalid_input_json` holds the rejected tool input text as received.
    InvalidToolInput {
        tool_use_id: LanguageModelToolUseId,
        ui_text: Arc<str>,
        invalid_input_json: Arc<str>,
    },
    /// Either the model's [`StopReason`] or the error that ended the turn. The error is wrapped
    /// in `Arc` because `anyhow::Error` is not `Clone`, and this enum derives `Clone`.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(AgentThreadMessageId),
    MessageEdited(AgentThreadMessageId),
    MessageDeleted(AgentThreadMessageId),
    SummaryGenerated,
    SummaryChanged,
    CheckpointChanged,
    ToolConfirmationNeeded,
    ToolUseLimitReached,
    CancelEditing,
    CompletionCanceled,
    ProfileChanged,
    /// Carries a message describing the final failure after retries.
    RetriesFailed {
        message: SharedString,
    },
}

impl EventEmitter<ThreadEvent> for Thread {}
1365
/// Bookkeeping for a pending completion request.
///
/// NOTE(review): `_task` is never read. It is presumably held only so the spawned work stays
/// alive as long as this struct does (gpui `Task`s are cancelled on drop); confirm.
struct PendingCompletion {
    id: usize,
    queue_state: QueueState,
    _task: Task<()>,
}
1371
1372/// Resolves tool name conflicts by ensuring all tool names are unique.
1373///
1374/// When multiple tools have the same name, this function applies the following rules:
1375/// 1. Native tools always keep their original name
1376/// 2. Context server tools get prefixed with their server ID and an underscore
1377/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
1378/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
1379///
1380/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
1381fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
1382 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
1383 let mut tool_name = tool.name();
1384 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
1385 tool_name
1386 }
1387
1388 const MAX_TOOL_NAME_LENGTH: usize = 64;
1389
1390 let mut duplicated_tool_names = HashSet::default();
1391 let mut seen_tool_names = HashSet::default();
1392 for tool in tools {
1393 let tool_name = resolve_tool_name(tool);
1394 if seen_tool_names.contains(&tool_name) {
1395 debug_assert!(
1396 tool.source() != assistant_tool::ToolSource::Native,
1397 "There are two built-in tools with the same name: {}",
1398 tool_name
1399 );
1400 duplicated_tool_names.insert(tool_name);
1401 } else {
1402 seen_tool_names.insert(tool_name);
1403 }
1404 }
1405
1406 if duplicated_tool_names.is_empty() {
1407 return tools
1408 .into_iter()
1409 .map(|tool| (resolve_tool_name(tool), tool.clone()))
1410 .collect();
1411 }
1412
1413 tools
1414 .into_iter()
1415 .filter_map(|tool| {
1416 let mut tool_name = resolve_tool_name(tool);
1417 if !duplicated_tool_names.contains(&tool_name) {
1418 return Some((tool_name, tool.clone()));
1419 }
1420 match tool.source() {
1421 assistant_tool::ToolSource::Native => {
1422 // Built-in tools always keep their original name
1423 Some((tool_name, tool.clone()))
1424 }
1425 assistant_tool::ToolSource::ContextServer { id } => {
1426 // Context server tools are prefixed with the context server ID, and truncated if necessary
1427 tool_name.insert(0, '_');
1428 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
1429 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
1430 let mut id = id.to_string();
1431 id.truncate(len);
1432 tool_name.insert_str(0, &id);
1433 } else {
1434 tool_name.insert_str(0, &id);
1435 }
1436
1437 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
1438
1439 if seen_tool_names.contains(&tool_name) {
1440 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
1441 None
1442 } else {
1443 Some((tool_name, tool.clone()))
1444 }
1445 }
1446 }
1447 })
1448 .collect()
1449}