1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5
6use agent_rules::load_worktree_rules_file;
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use fs::Fs;
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
19 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
20 LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
21 PaymentRequiredError, Role, StopReason, TokenUsage,
22};
23use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
24use project::{Project, Worktree};
25use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use settings::Settings;
29use util::{ResultExt as _, TryFutureExt as _, post_inc};
30use uuid::Uuid;
31
32use crate::context::{AssistantContext, ContextId, format_context_as_string};
33use crate::thread_store::{
34 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
35 SerializedToolUse,
36};
37use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
38
/// The purpose of a request built by `Thread::to_completion_request`.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A regular conversational request: tool results and tool uses are attached
    /// to the messages they belong to.
    Chat,
    /// Used when summarizing a thread.
    ///
    /// Messages that carry tool results are omitted and tool uses are not attached.
    Summarize,
}
45
46#[derive(
47 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
48)]
49pub struct ThreadId(Arc<str>);
50
51impl ThreadId {
52 pub fn new() -> Self {
53 Self(Uuid::new_v4().to_string().into())
54 }
55}
56
57impl std::fmt::Display for ThreadId {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 write!(f, "{}", self.0)
60 }
61}
62
63impl From<&str> for ThreadId {
64 fn from(value: &str) -> Self {
65 Self(value.into())
66 }
67}
68
69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
70pub struct MessageId(pub(crate) usize);
71
72impl MessageId {
73 fn post_inc(&mut self) -> Self {
74 Self(post_inc(&mut self.0))
75 }
76}
77
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier, unique within the owning thread.
    pub id: MessageId,
    /// Who authored the message.
    pub role: Role,
    /// Ordered text / thinking segments making up the message body.
    pub segments: Vec<MessageSegment>,
    /// Pre-rendered attached context; empty when none was attached. It is prepended
    /// to the segments when the message is rendered with `Message::to_string`.
    pub context: String,
}
86
87impl Message {
88 /// Returns whether the message contains any meaningful text that should be displayed
89 /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
90 pub fn should_display_content(&self) -> bool {
91 self.segments.iter().all(|segment| segment.should_display())
92 }
93
94 pub fn push_thinking(&mut self, text: &str) {
95 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
96 segment.push_str(text);
97 } else {
98 self.segments
99 .push(MessageSegment::Thinking(text.to_string()));
100 }
101 }
102
103 pub fn push_text(&mut self, text: &str) {
104 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
105 segment.push_str(text);
106 } else {
107 self.segments.push(MessageSegment::Text(text.to_string()));
108 }
109 }
110
111 pub fn to_string(&self) -> String {
112 let mut result = String::new();
113
114 if !self.context.is_empty() {
115 result.push_str(&self.context);
116 }
117
118 for segment in &self.segments {
119 match segment {
120 MessageSegment::Text(text) => result.push_str(text),
121 MessageSegment::Thinking(text) => {
122 result.push_str("<think>");
123 result.push_str(text);
124 result.push_str("</think>");
125 }
126 }
127 }
128
129 result
130 }
131}
132
/// One piece of a [`Message`] body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Ordinary message text.
    Text(String),
    /// Model "thinking" output; `Message::to_string` wraps it in `<think>` tags.
    Thinking(String),
}
138
139impl MessageSegment {
140 pub fn text_mut(&mut self) -> &mut String {
141 match self {
142 Self::Text(text) => text,
143 Self::Thinking(text) => text,
144 }
145 }
146
147 pub fn should_display(&self) -> bool {
148 // We add USING_TOOL_MARKER when making a request that includes tool uses
149 // without non-whitespace text around them, and this can cause the model
150 // to mimic the pattern, so we consider those segments not displayable.
151 match self {
152 Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
153 Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
154 }
155 }
156}
157
/// Snapshot of project state; a [`Thread`] keeps one as its initial project snapshot
/// and includes it when serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per worktree captured in the snapshot.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers that had unsaved changes (presumably at capture time; confirm).
    pub unsaved_buffer_paths: Vec<String>,
    /// Presumably when the snapshot was taken — TODO confirm against the producer.
    pub timestamp: DateTime<Utc>,
}
164
/// Per-worktree part of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree, stored as a string.
    pub worktree_path: String,
    /// Git information for the worktree, if any was captured.
    pub git_state: Option<GitState>,
}
170
/// Git information recorded in a [`WorktreeSnapshot`]. Every field is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// Presumably the repository's remote URL — confirm which remote is used.
    pub remote_url: Option<String>,
    /// Presumably the SHA of `HEAD` at capture time.
    pub head_sha: Option<String>,
    /// Presumably the checked-out branch name.
    pub current_branch: Option<String>,
    /// Presumably a textual diff of working changes — confirm diff type with producer.
    pub diff: Option<String>,
}
178
/// A git checkpoint tied to the message it was taken for.
///
/// Restoring it (`Thread::restore_checkpoint`) resets the git state and truncates the
/// thread starting at `message_id`.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    /// Message the checkpoint belongs to.
    message_id: MessageId,
    /// Git store checkpoint to restore.
    git_checkpoint: GitStoreCheckpoint,
}
184
/// A positive or negative rating, kept for the whole thread (`Thread::feedback`) or per
/// message (`Thread::message_feedback`). Presumably user-provided — confirm with callers.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
190
191pub enum LastRestoreCheckpoint {
192 Pending {
193 message_id: MessageId,
194 },
195 Error {
196 message_id: MessageId,
197 error: String,
198 },
199}
200
201impl LastRestoreCheckpoint {
202 pub fn message_id(&self) -> MessageId {
203 match self {
204 LastRestoreCheckpoint::Pending { message_id } => *message_id,
205 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
206 }
207 }
208}
209
/// Progress of the thread's detailed summary. Serialized with the thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary exists (the default).
    #[default]
    NotGenerated,
    /// A detailed summary is being produced.
    /// `message_id` is presumably the last message it covers — confirm.
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary is available; `Thread::latest_detailed_summary_or_text`
    /// prefers `text` over the full thread text.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
222
/// Token usage summary. Field meanings below are inferred from names — TODO confirm.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Presumably the tokens used so far.
    pub total: usize,
    /// Presumably the model's token limit.
    pub max: usize,
    /// Classification of `total` relative to `max`.
    pub ratio: TokenUsageRatio,
}
229
/// Coarse classification of token usage; `Normal` is the default.
/// Thresholds are not defined here — presumably computed by the producer of
/// [`TotalTokenUsage`]; confirm.
#[derive(Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
237
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Refreshed whenever messages are inserted, edited or deleted, and while streaming.
    updated_at: DateTime<Utc>,
    /// `None` until a summary exists; `set_summary` is ignored while this is `None`.
    summary: Option<SharedString>,
    /// Task for an in-flight summary (presumably; it is only initialized here).
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    /// Messages in conversation order.
    messages: Vec<Message>,
    /// Id handed to the next inserted message.
    next_message_id: MessageId,
    /// Every context item attached so far, keyed by id. Used to avoid attaching the same
    /// context to more than one message.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Context ids that each message introduced.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Needed to build the system prompt; a request built without it logs an error.
    system_prompt_context: Option<AssistantSystemPromptContext>,
    /// Git checkpoints that were kept, keyed by the message they belong to.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    /// Monotonic counter used to assign ids to pending completions.
    completion_count: usize,
    /// Non-empty while a completion is streaming (see `is_generating`).
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Arc<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken with the latest user message; kept or discarded once generation
    /// finishes (`finalize_pending_checkpoint`).
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Running token usage across completions; serialized with the thread.
    cumulative_token_usage: TokenUsage,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
}
265
266impl Thread {
267 pub fn new(
268 project: Entity<Project>,
269 tools: Arc<ToolWorkingSet>,
270 prompt_builder: Arc<PromptBuilder>,
271 cx: &mut Context<Self>,
272 ) -> Self {
273 Self {
274 id: ThreadId::new(),
275 updated_at: Utc::now(),
276 summary: None,
277 pending_summary: Task::ready(None),
278 detailed_summary_state: DetailedSummaryState::NotGenerated,
279 messages: Vec::new(),
280 next_message_id: MessageId(0),
281 context: BTreeMap::default(),
282 context_by_message: HashMap::default(),
283 system_prompt_context: None,
284 checkpoints_by_message: HashMap::default(),
285 completion_count: 0,
286 pending_completions: Vec::new(),
287 project: project.clone(),
288 prompt_builder,
289 tools: tools.clone(),
290 last_restore_checkpoint: None,
291 pending_checkpoint: None,
292 tool_use: ToolUseState::new(tools.clone()),
293 action_log: cx.new(|_| ActionLog::new(project.clone())),
294 initial_project_snapshot: {
295 let project_snapshot = Self::project_snapshot(project, cx);
296 cx.foreground_executor()
297 .spawn(async move { Some(project_snapshot.await) })
298 .shared()
299 },
300 cumulative_token_usage: TokenUsage::default(),
301 feedback: None,
302 message_feedback: HashMap::default(),
303 }
304 }
305
306 pub fn deserialize(
307 id: ThreadId,
308 serialized: SerializedThread,
309 project: Entity<Project>,
310 tools: Arc<ToolWorkingSet>,
311 prompt_builder: Arc<PromptBuilder>,
312 cx: &mut Context<Self>,
313 ) -> Self {
314 let next_message_id = MessageId(
315 serialized
316 .messages
317 .last()
318 .map(|message| message.id.0 + 1)
319 .unwrap_or(0),
320 );
321 let tool_use =
322 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
323
324 Self {
325 id,
326 updated_at: serialized.updated_at,
327 summary: Some(serialized.summary),
328 pending_summary: Task::ready(None),
329 detailed_summary_state: serialized.detailed_summary_state,
330 messages: serialized
331 .messages
332 .into_iter()
333 .map(|message| Message {
334 id: message.id,
335 role: message.role,
336 segments: message
337 .segments
338 .into_iter()
339 .map(|segment| match segment {
340 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
341 SerializedMessageSegment::Thinking { text } => {
342 MessageSegment::Thinking(text)
343 }
344 })
345 .collect(),
346 context: message.context,
347 })
348 .collect(),
349 next_message_id,
350 context: BTreeMap::default(),
351 context_by_message: HashMap::default(),
352 system_prompt_context: None,
353 checkpoints_by_message: HashMap::default(),
354 completion_count: 0,
355 pending_completions: Vec::new(),
356 last_restore_checkpoint: None,
357 pending_checkpoint: None,
358 project: project.clone(),
359 prompt_builder,
360 tools,
361 tool_use,
362 action_log: cx.new(|_| ActionLog::new(project)),
363 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
364 cumulative_token_usage: serialized.cumulative_token_usage,
365 feedback: None,
366 message_feedback: HashMap::default(),
367 }
368 }
369
370 pub fn id(&self) -> &ThreadId {
371 &self.id
372 }
373
374 pub fn is_empty(&self) -> bool {
375 self.messages.is_empty()
376 }
377
378 pub fn updated_at(&self) -> DateTime<Utc> {
379 self.updated_at
380 }
381
382 pub fn touch_updated_at(&mut self) {
383 self.updated_at = Utc::now();
384 }
385
386 pub fn summary(&self) -> Option<SharedString> {
387 self.summary.clone()
388 }
389
390 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
391
392 pub fn summary_or_default(&self) -> SharedString {
393 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
394 }
395
396 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
397 let Some(current_summary) = &self.summary else {
398 // Don't allow setting summary until generated
399 return;
400 };
401
402 let mut new_summary = new_summary.into();
403
404 if new_summary.is_empty() {
405 new_summary = Self::DEFAULT_SUMMARY;
406 }
407
408 if current_summary != &new_summary {
409 self.summary = Some(new_summary);
410 cx.emit(ThreadEvent::SummaryChanged);
411 }
412 }
413
414 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
415 self.latest_detailed_summary()
416 .unwrap_or_else(|| self.text().into())
417 }
418
419 fn latest_detailed_summary(&self) -> Option<SharedString> {
420 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
421 Some(text.clone())
422 } else {
423 None
424 }
425 }
426
427 pub fn message(&self, id: MessageId) -> Option<&Message> {
428 self.messages.iter().find(|message| message.id == id)
429 }
430
431 pub fn messages(&self) -> impl Iterator<Item = &Message> {
432 self.messages.iter()
433 }
434
435 pub fn is_generating(&self) -> bool {
436 !self.pending_completions.is_empty() || !self.all_tools_finished()
437 }
438
439 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
440 &self.tools
441 }
442
443 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
444 self.tool_use
445 .pending_tool_uses()
446 .into_iter()
447 .find(|tool_use| &tool_use.id == id)
448 }
449
450 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
451 self.tool_use
452 .pending_tool_uses()
453 .into_iter()
454 .filter(|tool_use| tool_use.status.needs_confirmation())
455 }
456
457 pub fn has_pending_tool_uses(&self) -> bool {
458 !self.tool_use.pending_tool_uses().is_empty()
459 }
460
461 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
462 self.checkpoints_by_message.get(&id).cloned()
463 }
464
465 pub fn restore_checkpoint(
466 &mut self,
467 checkpoint: ThreadCheckpoint,
468 cx: &mut Context<Self>,
469 ) -> Task<Result<()>> {
470 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
471 message_id: checkpoint.message_id,
472 });
473 cx.emit(ThreadEvent::CheckpointChanged);
474 cx.notify();
475
476 let git_store = self.project().read(cx).git_store().clone();
477 let restore = git_store.update(cx, |git_store, cx| {
478 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
479 });
480
481 cx.spawn(async move |this, cx| {
482 let result = restore.await;
483 this.update(cx, |this, cx| {
484 if let Err(err) = result.as_ref() {
485 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
486 message_id: checkpoint.message_id,
487 error: err.to_string(),
488 });
489 } else {
490 this.truncate(checkpoint.message_id, cx);
491 this.last_restore_checkpoint = None;
492 }
493 this.pending_checkpoint = None;
494 cx.emit(ThreadEvent::CheckpointChanged);
495 cx.notify();
496 })?;
497 result
498 })
499 }
500
501 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
502 let pending_checkpoint = if self.is_generating() {
503 return;
504 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
505 checkpoint
506 } else {
507 return;
508 };
509
510 let git_store = self.project.read(cx).git_store().clone();
511 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
512 cx.spawn(async move |this, cx| match final_checkpoint.await {
513 Ok(final_checkpoint) => {
514 let equal = git_store
515 .update(cx, |store, cx| {
516 store.compare_checkpoints(
517 pending_checkpoint.git_checkpoint.clone(),
518 final_checkpoint.clone(),
519 cx,
520 )
521 })?
522 .await
523 .unwrap_or(false);
524
525 if equal {
526 git_store
527 .update(cx, |store, cx| {
528 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
529 })?
530 .detach();
531 } else {
532 this.update(cx, |this, cx| {
533 this.insert_checkpoint(pending_checkpoint, cx)
534 })?;
535 }
536
537 git_store
538 .update(cx, |store, cx| {
539 store.delete_checkpoint(final_checkpoint, cx)
540 })?
541 .detach();
542
543 Ok(())
544 }
545 Err(_) => this.update(cx, |this, cx| {
546 this.insert_checkpoint(pending_checkpoint, cx)
547 }),
548 })
549 .detach();
550 }
551
552 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
553 self.checkpoints_by_message
554 .insert(checkpoint.message_id, checkpoint);
555 cx.emit(ThreadEvent::CheckpointChanged);
556 cx.notify();
557 }
558
559 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
560 self.last_restore_checkpoint.as_ref()
561 }
562
563 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
564 let Some(message_ix) = self
565 .messages
566 .iter()
567 .rposition(|message| message.id == message_id)
568 else {
569 return;
570 };
571 for deleted_message in self.messages.drain(message_ix..) {
572 self.context_by_message.remove(&deleted_message.id);
573 self.checkpoints_by_message.remove(&deleted_message.id);
574 }
575 cx.notify();
576 }
577
578 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
579 self.context_by_message
580 .get(&id)
581 .into_iter()
582 .flat_map(|context| {
583 context
584 .iter()
585 .filter_map(|context_id| self.context.get(&context_id))
586 })
587 }
588
589 /// Returns whether all of the tool uses have finished running.
590 pub fn all_tools_finished(&self) -> bool {
591 // If the only pending tool uses left are the ones with errors, then
592 // that means that we've finished running all of the pending tools.
593 self.tool_use
594 .pending_tool_uses()
595 .iter()
596 .all(|tool_use| tool_use.status.is_error())
597 }
598
599 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
600 self.tool_use.tool_uses_for_message(id, cx)
601 }
602
603 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
604 self.tool_use.tool_results_for_message(id)
605 }
606
607 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
608 self.tool_use.tool_result(id)
609 }
610
611 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
612 self.tool_use.message_has_tool_results(message_id)
613 }
614
615 pub fn insert_user_message(
616 &mut self,
617 text: impl Into<String>,
618 context: Vec<AssistantContext>,
619 git_checkpoint: Option<GitStoreCheckpoint>,
620 cx: &mut Context<Self>,
621 ) -> MessageId {
622 let text = text.into();
623
624 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
625
626 // Filter out contexts that have already been included in previous messages
627 let new_context: Vec<_> = context
628 .into_iter()
629 .filter(|ctx| !self.context.contains_key(&ctx.id()))
630 .collect();
631
632 if !new_context.is_empty() {
633 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
634 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
635 message.context = context_string;
636 }
637 }
638
639 self.action_log.update(cx, |log, cx| {
640 // Track all buffers added as context
641 for ctx in &new_context {
642 match ctx {
643 AssistantContext::File(file_ctx) => {
644 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
645 }
646 AssistantContext::Directory(dir_ctx) => {
647 for context_buffer in &dir_ctx.context_buffers {
648 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
649 }
650 }
651 AssistantContext::Symbol(symbol_ctx) => {
652 log.buffer_added_as_context(
653 symbol_ctx.context_symbol.buffer.clone(),
654 cx,
655 );
656 }
657 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
658 }
659 }
660 });
661 }
662
663 let context_ids = new_context
664 .iter()
665 .map(|context| context.id())
666 .collect::<Vec<_>>();
667 self.context.extend(
668 new_context
669 .into_iter()
670 .map(|context| (context.id(), context)),
671 );
672 self.context_by_message.insert(message_id, context_ids);
673
674 if let Some(git_checkpoint) = git_checkpoint {
675 self.pending_checkpoint = Some(ThreadCheckpoint {
676 message_id,
677 git_checkpoint,
678 });
679 }
680 message_id
681 }
682
683 pub fn insert_message(
684 &mut self,
685 role: Role,
686 segments: Vec<MessageSegment>,
687 cx: &mut Context<Self>,
688 ) -> MessageId {
689 let id = self.next_message_id.post_inc();
690 self.messages.push(Message {
691 id,
692 role,
693 segments,
694 context: String::new(),
695 });
696 self.touch_updated_at();
697 cx.emit(ThreadEvent::MessageAdded(id));
698 id
699 }
700
701 pub fn edit_message(
702 &mut self,
703 id: MessageId,
704 new_role: Role,
705 new_segments: Vec<MessageSegment>,
706 cx: &mut Context<Self>,
707 ) -> bool {
708 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
709 return false;
710 };
711 message.role = new_role;
712 message.segments = new_segments;
713 self.touch_updated_at();
714 cx.emit(ThreadEvent::MessageEdited(id));
715 true
716 }
717
718 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
719 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
720 return false;
721 };
722 self.messages.remove(index);
723 self.context_by_message.remove(&id);
724 self.touch_updated_at();
725 cx.emit(ThreadEvent::MessageDeleted(id));
726 true
727 }
728
729 /// Returns the representation of this [`Thread`] in a textual form.
730 ///
731 /// This is the representation we use when attaching a thread as context to another thread.
732 pub fn text(&self) -> String {
733 let mut text = String::new();
734
735 for message in &self.messages {
736 text.push_str(match message.role {
737 language_model::Role::User => "User:",
738 language_model::Role::Assistant => "Assistant:",
739 language_model::Role::System => "System:",
740 });
741 text.push('\n');
742
743 for segment in &message.segments {
744 match segment {
745 MessageSegment::Text(content) => text.push_str(content),
746 MessageSegment::Thinking(content) => {
747 text.push_str(&format!("<think>{}</think>", content))
748 }
749 }
750 }
751 text.push('\n');
752 }
753
754 text
755 }
756
757 /// Serializes this thread into a format for storage or telemetry.
758 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
759 let initial_project_snapshot = self.initial_project_snapshot.clone();
760 cx.spawn(async move |this, cx| {
761 let initial_project_snapshot = initial_project_snapshot.await;
762 this.read_with(cx, |this, cx| SerializedThread {
763 version: SerializedThread::VERSION.to_string(),
764 summary: this.summary_or_default(),
765 updated_at: this.updated_at(),
766 messages: this
767 .messages()
768 .map(|message| SerializedMessage {
769 id: message.id,
770 role: message.role,
771 segments: message
772 .segments
773 .iter()
774 .map(|segment| match segment {
775 MessageSegment::Text(text) => {
776 SerializedMessageSegment::Text { text: text.clone() }
777 }
778 MessageSegment::Thinking(text) => {
779 SerializedMessageSegment::Thinking { text: text.clone() }
780 }
781 })
782 .collect(),
783 tool_uses: this
784 .tool_uses_for_message(message.id, cx)
785 .into_iter()
786 .map(|tool_use| SerializedToolUse {
787 id: tool_use.id,
788 name: tool_use.name,
789 input: tool_use.input,
790 })
791 .collect(),
792 tool_results: this
793 .tool_results_for_message(message.id)
794 .into_iter()
795 .map(|tool_result| SerializedToolResult {
796 tool_use_id: tool_result.tool_use_id.clone(),
797 is_error: tool_result.is_error,
798 content: tool_result.content.clone(),
799 })
800 .collect(),
801 context: message.context.clone(),
802 })
803 .collect(),
804 initial_project_snapshot,
805 cumulative_token_usage: this.cumulative_token_usage.clone(),
806 detailed_summary_state: this.detailed_summary_state.clone(),
807 })
808 })
809 }
810
811 pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
812 self.system_prompt_context = Some(context);
813 }
814
815 pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
816 &self.system_prompt_context
817 }
818
819 pub fn load_system_prompt_context(
820 &self,
821 cx: &App,
822 ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
823 let project = self.project.read(cx);
824 let tasks = project
825 .visible_worktrees(cx)
826 .map(|worktree| {
827 Self::load_worktree_info_for_system_prompt(
828 project.fs().clone(),
829 worktree.read(cx),
830 cx,
831 )
832 })
833 .collect::<Vec<_>>();
834
835 cx.spawn(async |_cx| {
836 let results = futures::future::join_all(tasks).await;
837 let mut first_err = None;
838 let worktrees = results
839 .into_iter()
840 .map(|(worktree, err)| {
841 if first_err.is_none() && err.is_some() {
842 first_err = err;
843 }
844 worktree
845 })
846 .collect::<Vec<_>>();
847 (AssistantSystemPromptContext::new(worktrees), first_err)
848 })
849 }
850
851 fn load_worktree_info_for_system_prompt(
852 fs: Arc<dyn Fs>,
853 worktree: &Worktree,
854 cx: &App,
855 ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
856 let root_name = worktree.root_name().into();
857 let abs_path = worktree.abs_path();
858
859 let rules_task = load_worktree_rules_file(fs, worktree, cx);
860 let Some(rules_task) = rules_task else {
861 return Task::ready((
862 WorktreeInfoForSystemPrompt {
863 root_name,
864 abs_path,
865 rules_file: None,
866 },
867 None,
868 ));
869 };
870
871 cx.spawn(async move |_| {
872 let (rules_file, rules_file_error) = match rules_task.await {
873 Ok(rules_file) => (Some(rules_file), None),
874 Err(err) => (
875 None,
876 Some(ThreadError::Message {
877 header: "Error loading rules file".into(),
878 message: format!("{err}").into(),
879 }),
880 ),
881 };
882 let worktree_info = WorktreeInfoForSystemPrompt {
883 root_name,
884 abs_path,
885 rules_file,
886 };
887 (worktree_info, rules_file_error)
888 })
889 }
890
891 pub fn send_to_model(
892 &mut self,
893 model: Arc<dyn LanguageModel>,
894 request_kind: RequestKind,
895 cx: &mut Context<Self>,
896 ) {
897 let mut request = self.to_completion_request(request_kind, cx);
898 if model.supports_tools() {
899 request.tools = {
900 let mut tools = Vec::new();
901 tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
902 LanguageModelRequestTool {
903 name: tool.name(),
904 description: tool.description(),
905 input_schema: tool.input_schema(model.tool_input_format()),
906 }
907 }));
908
909 tools
910 };
911 }
912
913 self.stream_completion(request, model, cx);
914 }
915
916 pub fn used_tools_since_last_user_message(&self) -> bool {
917 for message in self.messages.iter().rev() {
918 if self.tool_use.message_has_tool_results(message.id) {
919 return true;
920 } else if message.role == Role::User {
921 return false;
922 }
923 }
924
925 false
926 }
927
928 pub fn to_completion_request(
929 &self,
930 request_kind: RequestKind,
931 cx: &App,
932 ) -> LanguageModelRequest {
933 let mut request = LanguageModelRequest {
934 messages: vec![],
935 tools: Vec::new(),
936 stop: Vec::new(),
937 temperature: None,
938 };
939
940 if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
941 if let Some(system_prompt) = self
942 .prompt_builder
943 .generate_assistant_system_prompt(system_prompt_context)
944 .context("failed to generate assistant system prompt")
945 .log_err()
946 {
947 request.messages.push(LanguageModelRequestMessage {
948 role: Role::System,
949 content: vec![MessageContent::Text(system_prompt)],
950 cache: true,
951 });
952 }
953 } else {
954 log::error!("system_prompt_context not set.")
955 }
956
957 for message in &self.messages {
958 let mut request_message = LanguageModelRequestMessage {
959 role: message.role,
960 content: Vec::new(),
961 cache: false,
962 };
963
964 match request_kind {
965 RequestKind::Chat => {
966 self.tool_use
967 .attach_tool_results(message.id, &mut request_message);
968 }
969 RequestKind::Summarize => {
970 // We don't care about tool use during summarization.
971 if self.tool_use.message_has_tool_results(message.id) {
972 continue;
973 }
974 }
975 }
976
977 if !message.segments.is_empty() {
978 request_message
979 .content
980 .push(MessageContent::Text(message.to_string()));
981 }
982
983 match request_kind {
984 RequestKind::Chat => {
985 self.tool_use
986 .attach_tool_uses(message.id, &mut request_message);
987 }
988 RequestKind::Summarize => {
989 // We don't care about tool use during summarization.
990 }
991 };
992
993 request.messages.push(request_message);
994 }
995
996 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
997 if let Some(last) = request.messages.last_mut() {
998 last.cache = true;
999 }
1000
1001 self.attached_tracked_files_state(&mut request.messages, cx);
1002
1003 request
1004 }
1005
1006 fn attached_tracked_files_state(
1007 &self,
1008 messages: &mut Vec<LanguageModelRequestMessage>,
1009 cx: &App,
1010 ) {
1011 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1012
1013 let mut stale_message = String::new();
1014
1015 let action_log = self.action_log.read(cx);
1016
1017 for stale_file in action_log.stale_buffers(cx) {
1018 let Some(file) = stale_file.read(cx).file() else {
1019 continue;
1020 };
1021
1022 if stale_message.is_empty() {
1023 write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1024 }
1025
1026 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1027 }
1028
1029 let mut content = Vec::with_capacity(2);
1030
1031 if !stale_message.is_empty() {
1032 content.push(stale_message.into());
1033 }
1034
1035 if action_log.has_edited_files_since_project_diagnostics_check() {
1036 content.push(
1037 "\n\nWhen you're done making changes, make sure to check project diagnostics \
1038 and fix all errors AND warnings you introduced! \
1039 DO NOT mention you're going to do this until you're done."
1040 .into(),
1041 );
1042 }
1043
1044 if !content.is_empty() {
1045 let context_message = LanguageModelRequestMessage {
1046 role: Role::User,
1047 content,
1048 cache: false,
1049 };
1050
1051 messages.push(context_message);
1052 }
1053 }
1054
1055 pub fn stream_completion(
1056 &mut self,
1057 request: LanguageModelRequest,
1058 model: Arc<dyn LanguageModel>,
1059 cx: &mut Context<Self>,
1060 ) {
1061 let pending_completion_id = post_inc(&mut self.completion_count);
1062
1063 let task = cx.spawn(async move |thread, cx| {
1064 let stream = model.stream_completion(request, &cx);
1065 let initial_token_usage =
1066 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
1067 let stream_completion = async {
1068 let mut events = stream.await?;
1069 let mut stop_reason = StopReason::EndTurn;
1070 let mut current_token_usage = TokenUsage::default();
1071
1072 while let Some(event) = events.next().await {
1073 let event = event?;
1074
1075 thread.update(cx, |thread, cx| {
1076 match event {
1077 LanguageModelCompletionEvent::StartMessage { .. } => {
1078 thread.insert_message(
1079 Role::Assistant,
1080 vec![MessageSegment::Text(String::new())],
1081 cx,
1082 );
1083 }
1084 LanguageModelCompletionEvent::Stop(reason) => {
1085 stop_reason = reason;
1086 }
1087 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1088 thread.cumulative_token_usage =
1089 thread.cumulative_token_usage.clone() + token_usage.clone()
1090 - current_token_usage.clone();
1091 current_token_usage = token_usage;
1092 }
1093 LanguageModelCompletionEvent::Text(chunk) => {
1094 if let Some(last_message) = thread.messages.last_mut() {
1095 if last_message.role == Role::Assistant {
1096 last_message.push_text(&chunk);
1097 cx.emit(ThreadEvent::StreamedAssistantText(
1098 last_message.id,
1099 chunk,
1100 ));
1101 } else {
1102 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1103 // of a new Assistant response.
1104 //
1105 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1106 // will result in duplicating the text of the chunk in the rendered Markdown.
1107 thread.insert_message(
1108 Role::Assistant,
1109 vec![MessageSegment::Text(chunk.to_string())],
1110 cx,
1111 );
1112 };
1113 }
1114 }
1115 LanguageModelCompletionEvent::Thinking(chunk) => {
1116 if let Some(last_message) = thread.messages.last_mut() {
1117 if last_message.role == Role::Assistant {
1118 last_message.push_thinking(&chunk);
1119 cx.emit(ThreadEvent::StreamedAssistantThinking(
1120 last_message.id,
1121 chunk,
1122 ));
1123 } else {
1124 // If we won't have an Assistant message yet, assume this chunk marks the beginning
1125 // of a new Assistant response.
1126 //
1127 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1128 // will result in duplicating the text of the chunk in the rendered Markdown.
1129 thread.insert_message(
1130 Role::Assistant,
1131 vec![MessageSegment::Thinking(chunk.to_string())],
1132 cx,
1133 );
1134 };
1135 }
1136 }
1137 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1138 let last_assistant_message_id = thread
1139 .messages
1140 .iter_mut()
1141 .rfind(|message| message.role == Role::Assistant)
1142 .map(|message| message.id)
1143 .unwrap_or_else(|| {
1144 thread.insert_message(Role::Assistant, vec![], cx)
1145 });
1146
1147 thread.tool_use.request_tool_use(
1148 last_assistant_message_id,
1149 tool_use,
1150 cx,
1151 );
1152 }
1153 }
1154
1155 thread.touch_updated_at();
1156 cx.emit(ThreadEvent::StreamedCompletion);
1157 cx.notify();
1158 })?;
1159
1160 smol::future::yield_now().await;
1161 }
1162
1163 thread.update(cx, |thread, cx| {
1164 thread
1165 .pending_completions
1166 .retain(|completion| completion.id != pending_completion_id);
1167
1168 if thread.summary.is_none() && thread.messages.len() >= 2 {
1169 thread.summarize(cx);
1170 }
1171 })?;
1172
1173 anyhow::Ok(stop_reason)
1174 };
1175
1176 let result = stream_completion.await;
1177
1178 thread
1179 .update(cx, |thread, cx| {
1180 thread.finalize_pending_checkpoint(cx);
1181 match result.as_ref() {
1182 Ok(stop_reason) => match stop_reason {
1183 StopReason::ToolUse => {
1184 let tool_uses = thread.use_pending_tools(cx);
1185 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1186 }
1187 StopReason::EndTurn => {}
1188 StopReason::MaxTokens => {}
1189 },
1190 Err(error) => {
1191 if error.is::<PaymentRequiredError>() {
1192 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1193 } else if error.is::<MaxMonthlySpendReachedError>() {
1194 cx.emit(ThreadEvent::ShowError(
1195 ThreadError::MaxMonthlySpendReached,
1196 ));
1197 } else {
1198 let error_message = error
1199 .chain()
1200 .map(|err| err.to_string())
1201 .collect::<Vec<_>>()
1202 .join("\n");
1203 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1204 header: "Error interacting with language model".into(),
1205 message: SharedString::from(error_message.clone()),
1206 }));
1207 }
1208
1209 thread.cancel_last_completion(cx);
1210 }
1211 }
1212 cx.emit(ThreadEvent::DoneStreaming);
1213
1214 if let Ok(initial_usage) = initial_token_usage {
1215 let usage = thread.cumulative_token_usage.clone() - initial_usage;
1216
1217 telemetry::event!(
1218 "Assistant Thread Completion",
1219 thread_id = thread.id().to_string(),
1220 model = model.telemetry_id(),
1221 model_provider = model.provider_id().to_string(),
1222 input_tokens = usage.input_tokens,
1223 output_tokens = usage.output_tokens,
1224 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1225 cache_read_input_tokens = usage.cache_read_input_tokens,
1226 );
1227 }
1228 })
1229 .ok();
1230 });
1231
1232 self.pending_completions.push(PendingCompletion {
1233 id: pending_completion_id,
1234 _task: task,
1235 });
1236 }
1237
1238 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1239 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1240 return;
1241 };
1242
1243 if !model.provider.is_authenticated(cx) {
1244 return;
1245 }
1246
1247 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1248 request.messages.push(LanguageModelRequestMessage {
1249 role: Role::User,
1250 content: vec![
1251 "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1252 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1253 If the conversation is about a specific subject, include it in the title. \
1254 Be descriptive. DO NOT speak in the first person."
1255 .into(),
1256 ],
1257 cache: false,
1258 });
1259
1260 self.pending_summary = cx.spawn(async move |this, cx| {
1261 async move {
1262 let stream = model.model.stream_completion_text(request, &cx);
1263 let mut messages = stream.await?;
1264
1265 let mut new_summary = String::new();
1266 while let Some(message) = messages.stream.next().await {
1267 let text = message?;
1268 let mut lines = text.lines();
1269 new_summary.extend(lines.next());
1270
1271 // Stop if the LLM generated multiple lines.
1272 if lines.next().is_some() {
1273 break;
1274 }
1275 }
1276
1277 this.update(cx, |this, cx| {
1278 if !new_summary.is_empty() {
1279 this.summary = Some(new_summary.into());
1280 }
1281
1282 cx.emit(ThreadEvent::SummaryGenerated);
1283 })?;
1284
1285 anyhow::Ok(())
1286 }
1287 .log_err()
1288 .await
1289 });
1290 }
1291
1292 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1293 let last_message_id = self.messages.last().map(|message| message.id)?;
1294
1295 match &self.detailed_summary_state {
1296 DetailedSummaryState::Generating { message_id, .. }
1297 | DetailedSummaryState::Generated { message_id, .. }
1298 if *message_id == last_message_id =>
1299 {
1300 // Already up-to-date
1301 return None;
1302 }
1303 _ => {}
1304 }
1305
1306 let ConfiguredModel { model, provider } =
1307 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1308
1309 if !provider.is_authenticated(cx) {
1310 return None;
1311 }
1312
1313 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1314
1315 request.messages.push(LanguageModelRequestMessage {
1316 role: Role::User,
1317 content: vec![
1318 "Generate a detailed summary of this conversation. Include:\n\
1319 1. A brief overview of what was discussed\n\
1320 2. Key facts or information discovered\n\
1321 3. Outcomes or conclusions reached\n\
1322 4. Any action items or next steps if any\n\
1323 Format it in Markdown with headings and bullet points."
1324 .into(),
1325 ],
1326 cache: false,
1327 });
1328
1329 let task = cx.spawn(async move |thread, cx| {
1330 let stream = model.stream_completion_text(request, &cx);
1331 let Some(mut messages) = stream.await.log_err() else {
1332 thread
1333 .update(cx, |this, _cx| {
1334 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1335 })
1336 .log_err();
1337
1338 return;
1339 };
1340
1341 let mut new_detailed_summary = String::new();
1342
1343 while let Some(chunk) = messages.stream.next().await {
1344 if let Some(chunk) = chunk.log_err() {
1345 new_detailed_summary.push_str(&chunk);
1346 }
1347 }
1348
1349 thread
1350 .update(cx, |this, _cx| {
1351 this.detailed_summary_state = DetailedSummaryState::Generated {
1352 text: new_detailed_summary.into(),
1353 message_id: last_message_id,
1354 };
1355 })
1356 .log_err();
1357 });
1358
1359 self.detailed_summary_state = DetailedSummaryState::Generating {
1360 message_id: last_message_id,
1361 };
1362
1363 Some(task)
1364 }
1365
1366 pub fn is_generating_detailed_summary(&self) -> bool {
1367 matches!(
1368 self.detailed_summary_state,
1369 DetailedSummaryState::Generating { .. }
1370 )
1371 }
1372
1373 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1374 let request = self.to_completion_request(RequestKind::Chat, cx);
1375 let messages = Arc::new(request.messages);
1376 let pending_tool_uses = self
1377 .tool_use
1378 .pending_tool_uses()
1379 .into_iter()
1380 .filter(|tool_use| tool_use.status.is_idle())
1381 .cloned()
1382 .collect::<Vec<_>>();
1383
1384 for tool_use in pending_tool_uses.iter() {
1385 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1386 if tool.needs_confirmation(&tool_use.input, cx)
1387 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1388 {
1389 self.tool_use.confirm_tool_use(
1390 tool_use.id.clone(),
1391 tool_use.ui_text.clone(),
1392 tool_use.input.clone(),
1393 messages.clone(),
1394 tool,
1395 );
1396 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1397 } else {
1398 self.run_tool(
1399 tool_use.id.clone(),
1400 tool_use.ui_text.clone(),
1401 tool_use.input.clone(),
1402 &messages,
1403 tool,
1404 cx,
1405 );
1406 }
1407 }
1408 }
1409
1410 pending_tool_uses
1411 }
1412
1413 pub fn run_tool(
1414 &mut self,
1415 tool_use_id: LanguageModelToolUseId,
1416 ui_text: impl Into<SharedString>,
1417 input: serde_json::Value,
1418 messages: &[LanguageModelRequestMessage],
1419 tool: Arc<dyn Tool>,
1420 cx: &mut Context<Thread>,
1421 ) {
1422 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1423 self.tool_use
1424 .run_pending_tool(tool_use_id, ui_text.into(), task);
1425 }
1426
1427 fn spawn_tool_use(
1428 &mut self,
1429 tool_use_id: LanguageModelToolUseId,
1430 messages: &[LanguageModelRequestMessage],
1431 input: serde_json::Value,
1432 tool: Arc<dyn Tool>,
1433 cx: &mut Context<Thread>,
1434 ) -> Task<()> {
1435 let tool_name: Arc<str> = tool.name().into();
1436
1437 let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
1438 Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
1439 } else {
1440 tool.run(
1441 input,
1442 messages,
1443 self.project.clone(),
1444 self.action_log.clone(),
1445 cx,
1446 )
1447 };
1448
1449 cx.spawn({
1450 async move |thread: WeakEntity<Thread>, cx| {
1451 let output = run_tool.await;
1452
1453 thread
1454 .update(cx, |thread, cx| {
1455 let pending_tool_use = thread.tool_use.insert_tool_output(
1456 tool_use_id.clone(),
1457 tool_name,
1458 output,
1459 cx,
1460 );
1461 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1462 })
1463 .ok();
1464 }
1465 })
1466 }
1467
1468 fn tool_finished(
1469 &mut self,
1470 tool_use_id: LanguageModelToolUseId,
1471 pending_tool_use: Option<PendingToolUse>,
1472 canceled: bool,
1473 cx: &mut Context<Self>,
1474 ) {
1475 if self.all_tools_finished() {
1476 let model_registry = LanguageModelRegistry::read_global(cx);
1477 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1478 self.attach_tool_results(cx);
1479 if !canceled {
1480 self.send_to_model(model, RequestKind::Chat, cx);
1481 }
1482 }
1483 }
1484
1485 cx.emit(ThreadEvent::ToolFinished {
1486 tool_use_id,
1487 pending_tool_use,
1488 });
1489 }
1490
1491 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1492 // Insert a user message to contain the tool results.
1493 self.insert_user_message(
1494 // TODO: Sending up a user message without any content results in the model sending back
1495 // responses that also don't have any content. We currently don't handle this case well,
1496 // so for now we provide some text to keep the model on track.
1497 "Here are the tool results.",
1498 Vec::new(),
1499 None,
1500 cx,
1501 );
1502 }
1503
1504 /// Cancels the last pending completion, if there are any pending.
1505 ///
1506 /// Returns whether a completion was canceled.
1507 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1508 let canceled = if self.pending_completions.pop().is_some() {
1509 true
1510 } else {
1511 let mut canceled = false;
1512 for pending_tool_use in self.tool_use.cancel_pending() {
1513 canceled = true;
1514 self.tool_finished(
1515 pending_tool_use.id.clone(),
1516 Some(pending_tool_use),
1517 true,
1518 cx,
1519 );
1520 }
1521 canceled
1522 };
1523 self.finalize_pending_checkpoint(cx);
1524 canceled
1525 }
1526
    /// Returns the thread-level feedback, if any was recorded.
    ///
    /// Note: `report_feedback` only sets this when the thread has no assistant message;
    /// otherwise feedback is stored per message (see `message_feedback`).
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }

    /// Returns the feedback recorded for the message `message_id`, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1534
    /// Records `feedback` for the message `message_id` and reports it via telemetry.
    ///
    /// If the same feedback is already recorded for that message, this returns an
    /// already-completed task and does nothing else.
    ///
    /// Otherwise the feedback is stored and observers are notified immediately. A background
    /// task then:
    /// - waits for a project snapshot and the serialized thread;
    /// - emits an "Assistant Thread Rated" telemetry event, including the enabled tool names
    ///   and the rated message's text;
    /// - flushes telemetry.
    ///
    /// The returned task fails if serializing the thread fails.
    pub fn report_message_feedback(
        &mut self,
        message_id: MessageId,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Nothing to do if this exact rating is already recorded for the message.
        if self.message_feedback.get(&message_id) == Some(&feedback) {
            return Task::ready(Ok(()));
        }

        // Capture everything the background task needs while `cx` is still available.
        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
        let serialized_thread = self.serialize(cx);
        let thread_id = self.id().clone();
        let client = self.project.read(cx).client();

        let enabled_tool_names: Vec<String> = self
            .tools()
            .enabled_tools(cx)
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();

        self.message_feedback.insert(message_id, feedback);

        cx.notify();

        let message_content = self
            .message(message_id)
            .map(|msg| msg.to_string())
            .unwrap_or_default();

        cx.background_spawn(async move {
            let final_project_snapshot = final_project_snapshot.await;
            let serialized_thread = serialized_thread.await?;
            // Best effort: a thread that fails to convert to JSON is reported as `null`.
            let thread_data =
                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);

            let rating = match feedback {
                ThreadFeedback::Positive => "positive",
                ThreadFeedback::Negative => "negative",
            };
            // NOTE(review): the shorthand fields below presumably use the local variable
            // names as event property names, so renaming these locals would change the
            // reported telemetry schema.
            telemetry::event!(
                "Assistant Thread Rated",
                rating,
                thread_id,
                enabled_tool_names,
                message_id = message_id.0,
                message_content,
                thread_data,
                final_project_snapshot
            );
            client.telemetry().flush_events();

            Ok(())
        })
    }
1591
    /// Records feedback for the thread, attaching it to the last assistant message if there
    /// is one.
    ///
    /// If the thread contains an assistant message, this delegates to
    /// `report_message_feedback` for the most recent one. Otherwise the feedback is stored
    /// as the thread-level `feedback` and reported via an "Assistant Thread Rated"
    /// telemetry event. Unlike the per-message variant, that event carries no message text,
    /// message id or enabled-tool names.
    ///
    /// NOTE(review): thread-level `self.feedback` is only set in the no-assistant-message
    /// case; confirm that callers of `feedback()` expect this.
    pub fn report_feedback(
        &mut self,
        feedback: ThreadFeedback,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let last_assistant_message_id = self
            .messages
            .iter()
            .rev()
            .find(|msg| msg.role == Role::Assistant)
            .map(|msg| msg.id);

        if let Some(message_id) = last_assistant_message_id {
            self.report_message_feedback(message_id, feedback, cx)
        } else {
            // No assistant message to attach the rating to: record it on the thread itself.
            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
            let serialized_thread = self.serialize(cx);
            let thread_id = self.id().clone();
            let client = self.project.read(cx).client();
            self.feedback = Some(feedback);
            cx.notify();

            cx.background_spawn(async move {
                let final_project_snapshot = final_project_snapshot.await;
                let serialized_thread = serialized_thread.await?;
                // Best effort: a thread that fails to convert to JSON is reported as `null`.
                let thread_data = serde_json::to_value(serialized_thread)
                    .unwrap_or_else(|_| serde_json::Value::Null);

                let rating = match feedback {
                    ThreadFeedback::Positive => "positive",
                    ThreadFeedback::Negative => "negative",
                };
                telemetry::event!(
                    "Assistant Thread Rated",
                    rating,
                    thread_id,
                    thread_data,
                    final_project_snapshot
                );
                client.telemetry().flush_events();

                Ok(())
            })
        }
    }
1637
1638 /// Create a snapshot of the current project state including git information and unsaved buffers.
1639 fn project_snapshot(
1640 project: Entity<Project>,
1641 cx: &mut Context<Self>,
1642 ) -> Task<Arc<ProjectSnapshot>> {
1643 let git_store = project.read(cx).git_store().clone();
1644 let worktree_snapshots: Vec<_> = project
1645 .read(cx)
1646 .visible_worktrees(cx)
1647 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1648 .collect();
1649
1650 cx.spawn(async move |_, cx| {
1651 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1652
1653 let mut unsaved_buffers = Vec::new();
1654 cx.update(|app_cx| {
1655 let buffer_store = project.read(app_cx).buffer_store();
1656 for buffer_handle in buffer_store.read(app_cx).buffers() {
1657 let buffer = buffer_handle.read(app_cx);
1658 if buffer.is_dirty() {
1659 if let Some(file) = buffer.file() {
1660 let path = file.path().to_string_lossy().to_string();
1661 unsaved_buffers.push(path);
1662 }
1663 }
1664 }
1665 })
1666 .ok();
1667
1668 Arc::new(ProjectSnapshot {
1669 worktree_snapshots,
1670 unsaved_buffer_paths: unsaved_buffers,
1671 timestamp: Utc::now(),
1672 })
1673 })
1674 }
1675
1676 fn worktree_snapshot(
1677 worktree: Entity<project::Worktree>,
1678 git_store: Entity<GitStore>,
1679 cx: &App,
1680 ) -> Task<WorktreeSnapshot> {
1681 cx.spawn(async move |cx| {
1682 // Get worktree path and snapshot
1683 let worktree_info = cx.update(|app_cx| {
1684 let worktree = worktree.read(app_cx);
1685 let path = worktree.abs_path().to_string_lossy().to_string();
1686 let snapshot = worktree.snapshot();
1687 (path, snapshot)
1688 });
1689
1690 let Ok((worktree_path, _snapshot)) = worktree_info else {
1691 return WorktreeSnapshot {
1692 worktree_path: String::new(),
1693 git_state: None,
1694 };
1695 };
1696
1697 let git_state = git_store
1698 .update(cx, |git_store, cx| {
1699 git_store
1700 .repositories()
1701 .values()
1702 .find(|repo| {
1703 repo.read(cx)
1704 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1705 .is_some()
1706 })
1707 .cloned()
1708 })
1709 .ok()
1710 .flatten()
1711 .map(|repo| {
1712 repo.update(cx, |repo, _| {
1713 let current_branch =
1714 repo.branch.as_ref().map(|branch| branch.name.to_string());
1715 repo.send_job(None, |state, _| async move {
1716 let RepositoryState::Local { backend, .. } = state else {
1717 return GitState {
1718 remote_url: None,
1719 head_sha: None,
1720 current_branch,
1721 diff: None,
1722 };
1723 };
1724
1725 let remote_url = backend.remote_url("origin");
1726 let head_sha = backend.head_sha();
1727 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1728
1729 GitState {
1730 remote_url,
1731 head_sha,
1732 current_branch,
1733 diff,
1734 }
1735 })
1736 })
1737 });
1738
1739 let git_state = match git_state {
1740 Some(git_state) => match git_state.ok() {
1741 Some(git_state) => git_state.await.ok(),
1742 None => None,
1743 },
1744 None => None,
1745 };
1746
1747 WorktreeSnapshot {
1748 worktree_path,
1749 git_state,
1750 }
1751 })
1752 }
1753
1754 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1755 let mut markdown = Vec::new();
1756
1757 if let Some(summary) = self.summary() {
1758 writeln!(markdown, "# {summary}\n")?;
1759 };
1760
1761 for message in self.messages() {
1762 writeln!(
1763 markdown,
1764 "## {role}\n",
1765 role = match message.role {
1766 Role::User => "User",
1767 Role::Assistant => "Assistant",
1768 Role::System => "System",
1769 }
1770 )?;
1771
1772 if !message.context.is_empty() {
1773 writeln!(markdown, "{}", message.context)?;
1774 }
1775
1776 for segment in &message.segments {
1777 match segment {
1778 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1779 MessageSegment::Thinking(text) => {
1780 writeln!(markdown, "<think>{}</think>\n", text)?
1781 }
1782 }
1783 }
1784
1785 for tool_use in self.tool_uses_for_message(message.id, cx) {
1786 writeln!(
1787 markdown,
1788 "**Use Tool: {} ({})**",
1789 tool_use.name, tool_use.id
1790 )?;
1791 writeln!(markdown, "```json")?;
1792 writeln!(
1793 markdown,
1794 "{}",
1795 serde_json::to_string_pretty(&tool_use.input)?
1796 )?;
1797 writeln!(markdown, "```")?;
1798 }
1799
1800 for tool_result in self.tool_results_for_message(message.id) {
1801 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1802 if tool_result.is_error {
1803 write!(markdown, " (Error)")?;
1804 }
1805
1806 writeln!(markdown, "**\n")?;
1807 writeln!(markdown, "{}", tool_result.content)?;
1808 }
1809 }
1810
1811 Ok(String::from_utf8_lossy(&markdown).to_string())
1812 }
1813
    /// Forwards to `ActionLog::keep_edits_in_range` for `buffer` within `buffer_range`.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }

    /// Forwards to `ActionLog::keep_all_edits`.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }

    /// Forwards to `ActionLog::reject_edits_in_range` for `buffer` within `buffer_range`,
    /// returning the task it produces.
    pub fn reject_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_range(buffer, buffer_range, cx)
        })
    }

    /// The action log shared with tools run by this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }

    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }

    /// Returns a copy of the token usage accumulated across this thread's completions.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage.clone()
    }
1852
1853 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1854 let model_registry = LanguageModelRegistry::read_global(cx);
1855 let Some(model) = model_registry.default_model() else {
1856 return TotalTokenUsage::default();
1857 };
1858
1859 let max = model.model.max_token_count();
1860
1861 #[cfg(debug_assertions)]
1862 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1863 .unwrap_or("0.8".to_string())
1864 .parse()
1865 .unwrap();
1866 #[cfg(not(debug_assertions))]
1867 let warning_threshold: f32 = 0.8;
1868
1869 let total = self.cumulative_token_usage.total_tokens() as usize;
1870
1871 let ratio = if total >= max {
1872 TokenUsageRatio::Exceeded
1873 } else if total as f32 / max as f32 >= warning_threshold {
1874 TokenUsageRatio::Warning
1875 } else {
1876 TokenUsageRatio::Normal
1877 };
1878
1879 TotalTokenUsage { total, max, ratio }
1880 }
1881
1882 pub fn deny_tool_use(
1883 &mut self,
1884 tool_use_id: LanguageModelToolUseId,
1885 tool_name: Arc<str>,
1886 cx: &mut Context<Self>,
1887 ) {
1888 let err = Err(anyhow::anyhow!(
1889 "Permission to run tool action denied by user"
1890 ));
1891
1892 self.tool_use
1893 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1894 self.tool_finished(tool_use_id.clone(), None, true, cx);
1895 }
1896}
1897
/// An error surfaced to the user through [`ThreadEvent::ShowError`].
#[derive(Debug, Clone)]
pub enum ThreadError {
    /// The completion failed with a `PaymentRequiredError`.
    PaymentRequired,
    /// The completion failed with a `MaxMonthlySpendReachedError`.
    MaxMonthlySpendReached,
    /// Any other error, shown with a short header and a message body.
    ///
    /// The completion path fills `message` with the error's cause chain, one cause per line.
    Message {
        header: SharedString,
        message: SharedString,
    },
}

/// Events emitted by a [`Thread`] as it streams completions, runs tools and changes state.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error that should be shown to the user.
    ShowError(ThreadError),
    /// Emitted after each event of a completion stream has been applied to the thread.
    StreamedCompletion,
    /// A text chunk was appended to the given assistant message.
    StreamedAssistantText(MessageId, String),
    /// A thinking chunk was appended to the given assistant message.
    StreamedAssistantThinking(MessageId, String),
    /// A completion stream ended, whether it succeeded or failed.
    DoneStreaming,
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// The title-summary request finished processing the model's reply.
    SummaryGenerated,
    SummaryChanged,
    /// The model stopped in order to use tools.
    ///
    /// `tool_uses` holds the idle tool uses that were started or queued for confirmation.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool use finished, was canceled, or was denied by the user.
    ToolFinished {
        // NOTE(review): not read by the code visible here; presumably kept for listeners.
        // Confirm before removing.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    /// A tool use needs the user's confirmation before it can run.
    ToolConfirmationNeeded,
}

impl EventEmitter<ThreadEvent> for Thread {}

/// A completion request in flight, tracked in `Thread::pending_completions`.
struct PendingCompletion {
    /// Identifies this completion so it can be removed once its stream ends.
    id: usize,
    /// The task driving the completion.
    ///
    /// `cancel_last_completion` cancels a completion by popping (and thereby dropping)
    /// this entry.
    _task: Task<()>,
}
1939
1940#[cfg(test)]
1941mod tests {
1942 use super::*;
1943 use crate::{ThreadStore, context_store::ContextStore, thread_store};
1944 use assistant_settings::AssistantSettings;
1945 use context_server::ContextServerSettings;
1946 use editor::EditorSettings;
1947 use gpui::TestAppContext;
1948 use project::{FakeFs, Project};
1949 use prompt_store::PromptBuilder;
1950 use serde_json::json;
1951 use settings::{Settings, SettingsStore};
1952 use std::sync::Arc;
1953 use theme::ThemeSettings;
1954 use util::path;
1955 use workspace::Workspace;
1956
    /// A user message with an attached file context:
    /// - stores the formatted `<context>` block on the message;
    /// - keeps the user's text as a single segment;
    /// - reaches the completion request as context followed by text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // `{{` / `}}` are escaped braces for `format!`; only `{path_part}` is interpolated.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request: the context is prepended to the user's text.
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 1);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[0].string_contents(), expected_full_message);
    }
2024
2025 #[gpui::test]
2026 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2027 init_test_settings(cx);
2028
2029 let project = create_test_project(
2030 cx,
2031 json!({
2032 "file1.rs": "fn function1() {}\n",
2033 "file2.rs": "fn function2() {}\n",
2034 "file3.rs": "fn function3() {}\n",
2035 }),
2036 )
2037 .await;
2038
2039 let (_, _thread_store, thread, context_store) =
2040 setup_test_environment(cx, project.clone()).await;
2041
2042 // Open files individually
2043 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2044 .await
2045 .unwrap();
2046 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2047 .await
2048 .unwrap();
2049 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2050 .await
2051 .unwrap();
2052
2053 // Get the context objects
2054 let contexts = context_store.update(cx, |store, _| store.context().clone());
2055 assert_eq!(contexts.len(), 3);
2056
2057 // First message with context 1
2058 let message1_id = thread.update(cx, |thread, cx| {
2059 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2060 });
2061
2062 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2063 let message2_id = thread.update(cx, |thread, cx| {
2064 thread.insert_user_message(
2065 "Message 2",
2066 vec![contexts[0].clone(), contexts[1].clone()],
2067 None,
2068 cx,
2069 )
2070 });
2071
2072 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2073 let message3_id = thread.update(cx, |thread, cx| {
2074 thread.insert_user_message(
2075 "Message 3",
2076 vec![
2077 contexts[0].clone(),
2078 contexts[1].clone(),
2079 contexts[2].clone(),
2080 ],
2081 None,
2082 cx,
2083 )
2084 });
2085
2086 // Check what contexts are included in each message
2087 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2088 (
2089 thread.message(message1_id).unwrap().clone(),
2090 thread.message(message2_id).unwrap().clone(),
2091 thread.message(message3_id).unwrap().clone(),
2092 )
2093 });
2094
2095 // First message should include context 1
2096 assert!(message1.context.contains("file1.rs"));
2097
2098 // Second message should include only context 2 (not 1)
2099 assert!(!message2.context.contains("file1.rs"));
2100 assert!(message2.context.contains("file2.rs"));
2101
2102 // Third message should include only context 3 (not 1 or 2)
2103 assert!(!message3.context.contains("file1.rs"));
2104 assert!(!message3.context.contains("file2.rs"));
2105 assert!(message3.context.contains("file3.rs"));
2106
2107 // Check entire request to make sure all contexts are properly included
2108 let request = thread.read_with(cx, |thread, cx| {
2109 thread.to_completion_request(RequestKind::Chat, cx)
2110 });
2111
2112 // The request should contain all 3 messages
2113 assert_eq!(request.messages.len(), 3);
2114
2115 // Check that the contexts are properly formatted in each message
2116 assert!(request.messages[0].string_contents().contains("file1.rs"));
2117 assert!(!request.messages[0].string_contents().contains("file2.rs"));
2118 assert!(!request.messages[0].string_contents().contains("file3.rs"));
2119
2120 assert!(!request.messages[1].string_contents().contains("file1.rs"));
2121 assert!(request.messages[1].string_contents().contains("file2.rs"));
2122 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2123
2124 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2125 assert!(!request.messages[2].string_contents().contains("file2.rs"));
2126 assert!(request.messages[2].string_contents().contains("file3.rs"));
2127 }
2128
2129 #[gpui::test]
2130 async fn test_message_without_files(cx: &mut TestAppContext) {
2131 init_test_settings(cx);
2132
2133 let project = create_test_project(
2134 cx,
2135 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2136 )
2137 .await;
2138
2139 let (_, _thread_store, thread, _context_store) =
2140 setup_test_environment(cx, project.clone()).await;
2141
2142 // Insert user message without any context (empty context vector)
2143 let message_id = thread.update(cx, |thread, cx| {
2144 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2145 });
2146
2147 // Check content and context in message object
2148 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2149
2150 // Context should be empty when no files are included
2151 assert_eq!(message.role, Role::User);
2152 assert_eq!(message.segments.len(), 1);
2153 assert_eq!(
2154 message.segments[0],
2155 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2156 );
2157 assert_eq!(message.context, "");
2158
2159 // Check message in request
2160 let request = thread.read_with(cx, |thread, cx| {
2161 thread.to_completion_request(RequestKind::Chat, cx)
2162 });
2163
2164 assert_eq!(request.messages.len(), 1);
2165 assert_eq!(
2166 request.messages[0].string_contents(),
2167 "What is the best way to learn Rust?"
2168 );
2169
2170 // Add second message, also without context
2171 let message2_id = thread.update(cx, |thread, cx| {
2172 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2173 });
2174
2175 let message2 =
2176 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2177 assert_eq!(message2.context, "");
2178
2179 // Check that both messages appear in the request
2180 let request = thread.read_with(cx, |thread, cx| {
2181 thread.to_completion_request(RequestKind::Chat, cx)
2182 });
2183
2184 assert_eq!(request.messages.len(), 2);
2185 assert_eq!(
2186 request.messages[0].string_contents(),
2187 "What is the best way to learn Rust?"
2188 );
2189 assert_eq!(
2190 request.messages[1].string_contents(),
2191 "Are there any good books?"
2192 );
2193 }
2194
    /// If a buffer attached as context is modified after it was read, the next completion
    /// request ends with a user message listing the changed file.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at byte offset 1 (inside the first line). Any edit works here: the
            // test only needs the buffer to change after it was read into context.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2273
2274 fn init_test_settings(cx: &mut TestAppContext) {
2275 cx.update(|cx| {
2276 let settings_store = SettingsStore::test(cx);
2277 cx.set_global(settings_store);
2278 language::init(cx);
2279 Project::init_settings(cx);
2280 AssistantSettings::register(cx);
2281 thread_store::init(cx);
2282 workspace::init_settings(cx);
2283 ThemeSettings::register(cx);
2284 ContextServerSettings::register(cx);
2285 EditorSettings::register(cx);
2286 });
2287 }
2288
2289 // Helper to create a test project with test files
2290 async fn create_test_project(
2291 cx: &mut TestAppContext,
2292 files: serde_json::Value,
2293 ) -> Entity<Project> {
2294 let fs = FakeFs::new(cx.executor());
2295 fs.insert_tree(path!("/test"), files).await;
2296 Project::test(fs, [path!("/test").as_ref()], cx).await
2297 }
2298
2299 async fn setup_test_environment(
2300 cx: &mut TestAppContext,
2301 project: Entity<Project>,
2302 ) -> (
2303 Entity<Workspace>,
2304 Entity<ThreadStore>,
2305 Entity<Thread>,
2306 Entity<ContextStore>,
2307 ) {
2308 let (workspace, cx) =
2309 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2310
2311 let thread_store = cx.update(|_, cx| {
2312 ThreadStore::new(
2313 project.clone(),
2314 Arc::default(),
2315 Arc::new(PromptBuilder::new(None).unwrap()),
2316 cx,
2317 )
2318 .unwrap()
2319 });
2320
2321 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2322 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2323
2324 (workspace, thread_store, thread, context_store)
2325 }
2326
2327 async fn add_file_to_context(
2328 project: &Entity<Project>,
2329 context_store: &Entity<ContextStore>,
2330 path: &str,
2331 cx: &mut TestAppContext,
2332 ) -> Result<Entity<language::Buffer>> {
2333 let buffer_path = project
2334 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2335 .unwrap();
2336
2337 let buffer = project
2338 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2339 .await
2340 .unwrap();
2341
2342 context_store
2343 .update(cx, |store, cx| {
2344 store.add_file_from_buffer(buffer.clone(), cx)
2345 })
2346 .await?;
2347
2348 Ok(buffer)
2349 }
2350}