1use std::fmt::Write as _;
2use std::io::Write;
3use std::sync::Arc;
4
5use anyhow::{Context as _, Result};
6use assistant_tool::{ActionLog, ToolWorkingSet};
7use chrono::{DateTime, Utc};
8use collections::{BTreeMap, HashMap, HashSet};
9use fs::Fs;
10use futures::future::Shared;
11use futures::{FutureExt, StreamExt as _};
12use git;
13use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task};
14use language_model::{
15 LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
16 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
17 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
18 Role, StopReason, TokenUsage,
19};
20use project::git::GitStoreCheckpoint;
21use project::{Project, Worktree};
22use prompt_store::{
23 AssistantSystemPromptContext, PromptBuilder, RulesFile, WorktreeInfoForSystemPrompt,
24};
25use scripting_tool::{ScriptingSession, ScriptingTool};
26use serde::{Deserialize, Serialize};
27use util::{maybe, post_inc, ResultExt as _, TryFutureExt as _};
28use uuid::Uuid;
29
30use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
31use crate::thread_store::{
32 SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse,
33};
34use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
35
/// The purpose of a request built by [`Thread::to_completion_request`].
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A regular chat turn; tool uses and tool results are included in the request.
    Chat,
    /// Used when summarizing a thread.
    ///
    /// Tool uses and tool results are left out of the request.
    Summarize,
}
42
/// Identifier of a [`Thread`].
///
/// Backed by an `Arc<str>`, so clones are cheap. Fresh ids are generated by
/// [`ThreadId::new`] from a random v4 UUID.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct ThreadId(Arc<str>);
45
46impl ThreadId {
47 pub fn new() -> Self {
48 Self(Uuid::new_v4().to_string().into())
49 }
50}
51
52impl std::fmt::Display for ThreadId {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(f, "{}", self.0)
55 }
56}
57
/// Identifier of a [`Message`] within a single [`Thread`].
///
/// Ids are handed out sequentially, starting from `0` for a new thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
60
61impl MessageId {
62 fn post_inc(&mut self) -> Self {
63 Self(post_inc(&mut self.0))
64 }
65}
66
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier of the message, unique within its thread.
    pub id: MessageId,
    /// Who authored the message (user, assistant or system).
    pub role: Role,
    /// The message's text; assistant messages grow as completion chunks stream in.
    pub text: String,
}
74
/// A point-in-time description of a project, stored alongside a serialized thread as its
/// initial project snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per worktree captured in the snapshot.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers with unsaved changes — presumably at snapshot time; TODO confirm
    /// against `Thread::project_snapshot`.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
81
/// Snapshot of a single worktree within a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree, stored as a string.
    pub worktree_path: String,
    /// Git information for the worktree; `None` when none was captured (e.g. presumably
    /// when the worktree is not in a repository — NOTE(review): confirm).
    pub git_state: Option<GitState>,
}
87
/// Git repository state recorded for a [`WorktreeSnapshot`].
///
/// Every field is optional; each is `None` when that piece of information was not
/// available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the repository remote.
    pub remote_url: Option<String>,
    /// SHA of the `HEAD` commit.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch.
    pub current_branch: Option<String>,
    /// Textual diff of the working tree (presumably against `HEAD` — TODO confirm).
    pub diff: Option<String>,
}
95
/// A git checkpoint tied to the thread message it was recorded for.
///
/// Obtained from [`Thread::checkpoint_for_message`] and consumed by
/// [`Thread::restore_checkpoint`].
#[derive(Clone)]
pub struct ThreadCheckpoint {
    /// The message the checkpoint belongs to; restoring truncates the thread at this message.
    message_id: MessageId,
    /// The git store checkpoint to restore.
    git_checkpoint: GitStoreCheckpoint,
}
101
/// Status of the most recent [`Thread::restore_checkpoint`] call.
///
/// A successful restore clears the status entirely, so only in-progress and failed restores
/// are represented.
pub enum LastRestoreCheckpoint {
    /// The restore for `message_id` has started but not finished yet.
    Pending {
        message_id: MessageId,
    },
    /// The restore for `message_id` failed with `error` (the error's display text).
    Error {
        message_id: MessageId,
        error: String,
    },
}
111
112impl LastRestoreCheckpoint {
113 pub fn message_id(&self) -> MessageId {
114 match self {
115 LastRestoreCheckpoint::Pending { message_id } => *message_id,
116 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
117 }
118 }
119}
120
/// A thread of conversation with the LLM.
///
/// Holds the ordered messages, the context attached to them, git checkpoints recorded for
/// messages, and the state of tool uses requested by the model, and drives completion
/// requests against a [`LanguageModel`].
pub struct Thread {
    /// Identifier of this thread.
    id: ThreadId,
    /// Last time the thread changed; refreshed by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// Short title of the thread; `None` until set or generated by `summarize`.
    summary: Option<SharedString>,
    /// Task for the summary request started by `summarize`; replaced on every call.
    pending_summary: Task<Option<()>>,
    /// All messages of the thread, in insertion order.
    messages: Vec<Message>,
    /// Id that the next inserted message will receive.
    next_message_id: MessageId,
    /// Every context snapshot attached to any message, keyed by its id.
    context: BTreeMap<ContextId, ContextSnapshot>,
    /// Ids of the context attached to each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Worktree information used to render the system prompt. Requests built while this is
    /// `None` log an error and carry no system message.
    system_prompt_context: Option<AssistantSystemPromptContext>,
    /// Git checkpoints recorded for messages, used by `restore_checkpoint`.
    checkpoints_by_message: HashMap<MessageId, GitStoreCheckpoint>,
    /// Counter that hands out ids for `pending_completions`.
    completion_count: usize,
    /// Completion requests that are still in flight.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    /// The tools offered to the model.
    tools: Arc<ToolWorkingSet>,
    /// Tool-use state for every tool except the scripting tool.
    tool_use: ToolUseState,
    /// Handed to tools when they run; also queried for buffers that went stale.
    action_log: Entity<ActionLog>,
    /// Status of the latest `restore_checkpoint` call while it is pending or after it failed.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Session in which scripting tool uses run their Lua scripts.
    scripting_session: Entity<ScriptingSession>,
    /// Tool-use state for the scripting tool (`ScriptingTool::NAME`) only.
    scripting_tool_use: ToolUseState,
    /// Project snapshot taken when the thread was created (or loaded from storage); shared so
    /// it can be awaited repeatedly when serializing.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Token usage summed over this thread's completions; not persisted.
    cumulative_token_usage: TokenUsage,
}
146
147impl Thread {
    /// Creates an empty thread for `project`.
    ///
    /// Besides empty message/context state, this creates a fresh [`ScriptingSession`] and
    /// [`ActionLog`], and starts capturing a [`ProjectSnapshot`] of the project as it is now.
    pub fn new(
        project: Entity<Project>,
        tools: Arc<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            messages: Vec::new(),
            next_message_id: MessageId(0),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            system_prompt_context: None,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            // Regular tools and the scripting tool are tracked in separate states that share
            // the same tool working set.
            tool_use: ToolUseState::new(tools.clone()),
            scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
            scripting_tool_use: ToolUseState::new(tools),
            action_log: cx.new(|_| ActionLog::new()),
            initial_project_snapshot: {
                // Start capturing the snapshot right away; `shared()` lets `serialize` clone
                // and await the same result any number of times.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            cumulative_token_usage: TokenUsage::default(),
        }
    }
184
    /// Rebuilds a thread with the given `id` from its serialized form.
    ///
    /// Messages, summary, update time, tool uses/results and the initial project snapshot are
    /// restored. Context snapshots, per-message context, checkpoints, the system prompt
    /// context and token usage are not restored here and start out empty.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Arc<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the last stored message (or from 0 for an empty thread).
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        // Serialized tool uses are stored together; split them by tool name into the regular
        // and the scripting tool-use states.
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |name| {
                name != ScriptingTool::NAME
            });
        let scripting_tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |name| {
                name == ScriptingTool::NAME
            });
        let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    text: message.text,
                })
                .collect(),
            next_message_id,
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            system_prompt_context: None,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            project,
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new()),
            scripting_session,
            scripting_tool_use,
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            // TODO: persist token usage?
            cumulative_token_usage: TokenUsage::default(),
        }
    }
244
245 pub fn id(&self) -> &ThreadId {
246 &self.id
247 }
248
249 pub fn is_empty(&self) -> bool {
250 self.messages.is_empty()
251 }
252
253 pub fn updated_at(&self) -> DateTime<Utc> {
254 self.updated_at
255 }
256
257 pub fn touch_updated_at(&mut self) {
258 self.updated_at = Utc::now();
259 }
260
261 pub fn summary(&self) -> Option<SharedString> {
262 self.summary.clone()
263 }
264
265 pub fn summary_or_default(&self) -> SharedString {
266 const DEFAULT: SharedString = SharedString::new_static("New Thread");
267 self.summary.clone().unwrap_or(DEFAULT)
268 }
269
270 pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
271 self.summary = Some(summary.into());
272 cx.emit(ThreadEvent::SummaryChanged);
273 }
274
275 pub fn message(&self, id: MessageId) -> Option<&Message> {
276 self.messages.iter().find(|message| message.id == id)
277 }
278
279 pub fn messages(&self) -> impl Iterator<Item = &Message> {
280 self.messages.iter()
281 }
282
283 pub fn is_generating(&self) -> bool {
284 !self.pending_completions.is_empty() || !self.all_tools_finished()
285 }
286
287 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
288 &self.tools
289 }
290
291 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
292 let checkpoint = self.checkpoints_by_message.get(&id).cloned()?;
293 Some(ThreadCheckpoint {
294 message_id: id,
295 git_checkpoint: checkpoint,
296 })
297 }
298
299 pub fn restore_checkpoint(
300 &mut self,
301 checkpoint: ThreadCheckpoint,
302 cx: &mut Context<Self>,
303 ) -> Task<Result<()>> {
304 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
305 message_id: checkpoint.message_id,
306 });
307 cx.emit(ThreadEvent::CheckpointChanged);
308
309 let project = self.project.read(cx);
310 let restore = project
311 .git_store()
312 .read(cx)
313 .restore_checkpoint(checkpoint.git_checkpoint, cx);
314 cx.spawn(async move |this, cx| {
315 let result = restore.await;
316 this.update(cx, |this, cx| {
317 if let Err(err) = result.as_ref() {
318 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
319 message_id: checkpoint.message_id,
320 error: err.to_string(),
321 });
322 } else {
323 this.last_restore_checkpoint = None;
324 this.truncate(checkpoint.message_id, cx);
325 }
326 cx.emit(ThreadEvent::CheckpointChanged);
327 })?;
328 result
329 })
330 }
331
332 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
333 self.last_restore_checkpoint.as_ref()
334 }
335
336 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
337 let Some(message_ix) = self
338 .messages
339 .iter()
340 .rposition(|message| message.id == message_id)
341 else {
342 return;
343 };
344 for deleted_message in self.messages.drain(message_ix..) {
345 self.context_by_message.remove(&deleted_message.id);
346 self.checkpoints_by_message.remove(&deleted_message.id);
347 }
348 cx.notify();
349 }
350
351 pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
352 let context = self.context_by_message.get(&id)?;
353 Some(
354 context
355 .into_iter()
356 .filter_map(|context_id| self.context.get(&context_id))
357 .cloned()
358 .collect::<Vec<_>>(),
359 )
360 }
361
362 /// Returns whether all of the tool uses have finished running.
363 pub fn all_tools_finished(&self) -> bool {
364 let mut all_pending_tool_uses = self
365 .tool_use
366 .pending_tool_uses()
367 .into_iter()
368 .chain(self.scripting_tool_use.pending_tool_uses());
369
370 // If the only pending tool uses left are the ones with errors, then
371 // that means that we've finished running all of the pending tools.
372 all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
373 }
374
375 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
376 self.tool_use.tool_uses_for_message(id, cx)
377 }
378
379 pub fn scripting_tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
380 self.scripting_tool_use.tool_uses_for_message(id, cx)
381 }
382
383 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
384 self.tool_use.tool_results_for_message(id)
385 }
386
387 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
388 self.tool_use.tool_result(id)
389 }
390
391 pub fn scripting_tool_results_for_message(
392 &self,
393 id: MessageId,
394 ) -> Vec<&LanguageModelToolResult> {
395 self.scripting_tool_use.tool_results_for_message(id)
396 }
397
398 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
399 self.tool_use.message_has_tool_results(message_id)
400 }
401
402 pub fn message_has_scripting_tool_results(&self, message_id: MessageId) -> bool {
403 self.scripting_tool_use.message_has_tool_results(message_id)
404 }
405
406 pub fn insert_user_message(
407 &mut self,
408 text: impl Into<String>,
409 context: Vec<ContextSnapshot>,
410 checkpoint: Option<GitStoreCheckpoint>,
411 cx: &mut Context<Self>,
412 ) -> MessageId {
413 let message_id = self.insert_message(Role::User, text, cx);
414 let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
415 self.context
416 .extend(context.into_iter().map(|context| (context.id, context)));
417 self.context_by_message.insert(message_id, context_ids);
418 if let Some(checkpoint) = checkpoint {
419 self.checkpoints_by_message.insert(message_id, checkpoint);
420 }
421 message_id
422 }
423
424 pub fn insert_message(
425 &mut self,
426 role: Role,
427 text: impl Into<String>,
428 cx: &mut Context<Self>,
429 ) -> MessageId {
430 let id = self.next_message_id.post_inc();
431 self.messages.push(Message {
432 id,
433 role,
434 text: text.into(),
435 });
436 self.touch_updated_at();
437 cx.emit(ThreadEvent::MessageAdded(id));
438 id
439 }
440
441 pub fn edit_message(
442 &mut self,
443 id: MessageId,
444 new_role: Role,
445 new_text: String,
446 cx: &mut Context<Self>,
447 ) -> bool {
448 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
449 return false;
450 };
451 message.role = new_role;
452 message.text = new_text;
453 self.touch_updated_at();
454 cx.emit(ThreadEvent::MessageEdited(id));
455 true
456 }
457
458 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
459 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
460 return false;
461 };
462 self.messages.remove(index);
463 self.context_by_message.remove(&id);
464 self.touch_updated_at();
465 cx.emit(ThreadEvent::MessageDeleted(id));
466 true
467 }
468
469 /// Returns the representation of this [`Thread`] in a textual form.
470 ///
471 /// This is the representation we use when attaching a thread as context to another thread.
472 pub fn text(&self) -> String {
473 let mut text = String::new();
474
475 for message in &self.messages {
476 text.push_str(match message.role {
477 language_model::Role::User => "User:",
478 language_model::Role::Assistant => "Assistant:",
479 language_model::Role::System => "System:",
480 });
481 text.push('\n');
482
483 text.push_str(&message.text);
484 text.push('\n');
485 }
486
487 text
488 }
489
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot first, then reads the thread's current state.
    /// For each message, regular and scripting tool uses/results are merged into one list.
    /// The summary falls back to the default title when none is set. Fails if the thread
    /// entity has been released by the time the snapshot resolves.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        text: message.text.clone(),
                        // Regular tool uses first, then scripting ones; `deserialize` splits
                        // them again by tool name.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .chain(this.scripting_tool_uses_for_message(message.id, cx))
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .chain(this.scripting_tool_results_for_message(message.id))
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
            })
        })
    }
530
531 pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
532 self.system_prompt_context = Some(context);
533 }
534
535 pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
536 &self.system_prompt_context
537 }
538
539 pub fn load_system_prompt_context(
540 &self,
541 cx: &App,
542 ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
543 let project = self.project.read(cx);
544 let tasks = project
545 .visible_worktrees(cx)
546 .map(|worktree| {
547 Self::load_worktree_info_for_system_prompt(
548 project.fs().clone(),
549 worktree.read(cx),
550 cx,
551 )
552 })
553 .collect::<Vec<_>>();
554
555 cx.spawn(async |_cx| {
556 let results = futures::future::join_all(tasks).await;
557 let mut first_err = None;
558 let worktrees = results
559 .into_iter()
560 .map(|(worktree, err)| {
561 if first_err.is_none() && err.is_some() {
562 first_err = err;
563 }
564 worktree
565 })
566 .collect::<Vec<_>>();
567 (AssistantSystemPromptContext::new(worktrees), first_err)
568 })
569 }
570
571 fn load_worktree_info_for_system_prompt(
572 fs: Arc<dyn Fs>,
573 worktree: &Worktree,
574 cx: &App,
575 ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
576 let root_name = worktree.root_name().into();
577 let abs_path = worktree.abs_path();
578
579 // Note that Cline supports `.clinerules` being a directory, but that is not currently
580 // supported. This doesn't seem to occur often in GitHub repositories.
581 const RULES_FILE_NAMES: [&'static str; 5] = [
582 ".rules",
583 ".cursorrules",
584 ".windsurfrules",
585 ".clinerules",
586 "CLAUDE.md",
587 ];
588 let selected_rules_file = RULES_FILE_NAMES
589 .into_iter()
590 .filter_map(|name| {
591 worktree
592 .entry_for_path(name)
593 .filter(|entry| entry.is_file())
594 .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
595 })
596 .next();
597
598 if let Some((rel_rules_path, abs_rules_path)) = selected_rules_file {
599 cx.spawn(async move |_| {
600 let rules_file_result = maybe!(async move {
601 let abs_rules_path = abs_rules_path?;
602 let text = fs.load(&abs_rules_path).await.with_context(|| {
603 format!("Failed to load assistant rules file {:?}", abs_rules_path)
604 })?;
605 anyhow::Ok(RulesFile {
606 rel_path: rel_rules_path,
607 abs_path: abs_rules_path.into(),
608 text: text.trim().to_string(),
609 })
610 })
611 .await;
612 let (rules_file, rules_file_error) = match rules_file_result {
613 Ok(rules_file) => (Some(rules_file), None),
614 Err(err) => (
615 None,
616 Some(ThreadError::Message {
617 header: "Error loading rules file".into(),
618 message: format!("{err}").into(),
619 }),
620 ),
621 };
622 let worktree_info = WorktreeInfoForSystemPrompt {
623 root_name,
624 abs_path,
625 rules_file,
626 };
627 (worktree_info, rules_file_error)
628 })
629 } else {
630 Task::ready((
631 WorktreeInfoForSystemPrompt {
632 root_name,
633 abs_path,
634 rules_file: None,
635 },
636 None,
637 ))
638 }
639 }
640
641 pub fn send_to_model(
642 &mut self,
643 model: Arc<dyn LanguageModel>,
644 request_kind: RequestKind,
645 cx: &mut Context<Self>,
646 ) {
647 let mut request = self.to_completion_request(request_kind, cx);
648 request.tools = {
649 let mut tools = Vec::new();
650
651 if self.tools.is_scripting_tool_enabled() {
652 tools.push(LanguageModelRequestTool {
653 name: ScriptingTool::NAME.into(),
654 description: ScriptingTool::DESCRIPTION.into(),
655 input_schema: ScriptingTool::input_schema(),
656 });
657 }
658
659 tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
660 LanguageModelRequestTool {
661 name: tool.name(),
662 description: tool.description(),
663 input_schema: tool.input_schema(),
664 }
665 }));
666
667 tools
668 };
669
670 self.stream_completion(request, model, cx);
671 }
672
    /// Converts the thread into a [`LanguageModelRequest`].
    ///
    /// The request contains, in order:
    /// 1. a cached system message rendered from `system_prompt_context` (omitted, with an
    ///    error logged, when the context is not set or rendering fails);
    /// 2. one request message per thread message;
    /// 3. one trailing user message with all context referenced by any message (if any);
    /// 4. a trailing user message listing stale files (if any; see `attach_stale_files`).
    ///
    /// Tool uses and tool results are only included for [`RequestKind::Chat`]. `tools` is
    /// left empty; callers fill it in.
    pub fn to_completion_request(
        &self,
        request_kind: RequestKind,
        cx: &App,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
            if let Some(system_prompt) = self
                .prompt_builder
                .generate_assistant_system_prompt(system_prompt_context)
                .context("failed to generate assistant system prompt")
                .log_err()
            {
                request.messages.push(LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text(system_prompt)],
                    cache: true,
                });
            }
        } else {
            log::error!("system_prompt_context not set.")
        }

        // Context ids referenced by any message; sent once, after all messages.
        // NOTE(review): iteration order of a hash set is unspecified, so the order of items in
        // the context message is not tied to message order.
        let mut referenced_context_ids = HashSet::default();

        for message in &self.messages {
            if let Some(context_ids) = self.context_by_message.get(&message.id) {
                referenced_context_ids.extend(context_ids);
            }

            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            // Tool results go before the message text...
            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_results(message.id, &mut request_message);
                    self.scripting_tool_use
                        .attach_tool_results(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            }

            if !message.text.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.text.clone()));
            }

            // ...and tool uses go after it.
            match request_kind {
                RequestKind::Chat => {
                    self.tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                    self.scripting_tool_use
                        .attach_tool_uses(message.id, &mut request_message);
                }
                RequestKind::Summarize => {
                    // We don't care about tool use during summarization.
                }
            };

            request.messages.push(request_message);
        }

        if !referenced_context_ids.is_empty() {
            let mut context_message = LanguageModelRequestMessage {
                role: Role::User,
                content: Vec::new(),
                cache: false,
            };

            // Ids without a stored snapshot are skipped.
            let referenced_context = referenced_context_ids
                .into_iter()
                .filter_map(|context_id| self.context.get(context_id))
                .cloned();
            attach_context_to_message(&mut context_message, referenced_context);

            request.messages.push(context_message);
        }

        self.attach_stale_files(&mut request.messages, cx);

        request
    }
768
769 fn attach_stale_files(&self, messages: &mut Vec<LanguageModelRequestMessage>, cx: &App) {
770 const STALE_FILES_HEADER: &str = "These files changed since last read:";
771
772 let mut stale_message = String::new();
773
774 for stale_file in self.action_log.read(cx).stale_buffers(cx) {
775 let Some(file) = stale_file.read(cx).file() else {
776 continue;
777 };
778
779 if stale_message.is_empty() {
780 write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
781 }
782
783 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
784 }
785
786 if !stale_message.is_empty() {
787 let context_message = LanguageModelRequestMessage {
788 role: Role::User,
789 content: vec![stale_message.into()],
790 cache: false,
791 };
792
793 messages.push(context_message);
794 }
795 }
796
    /// Streams a completion for `request` from `model` into this thread.
    ///
    /// The work runs in a spawned task that is recorded in `pending_completions`. While
    /// streaming it:
    /// - inserts an assistant message on `StartMessage`, or when text arrives while the last
    ///   message is not from the assistant;
    /// - appends text chunks to the last assistant message;
    /// - routes each tool use to the scripting or regular tool-use state by tool name,
    ///   attaching it to the latest assistant message;
    /// - keeps `cumulative_token_usage` up to date from usage updates.
    ///
    /// After a successful stream the completion is removed from `pending_completions`. A
    /// summary is requested if the thread has none and has at least two messages. Then
    /// `UsePendingTools` is emitted when the model stopped for tool use. On failure a
    /// `ShowError` event is emitted and the last completion is cancelled. In both cases
    /// `DoneStreaming` follows, and a telemetry event records the tokens used by this
    /// completion.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            // Usage before this completion, so the telemetry event can report only the delta.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;
                // Latest usage reported for this completion so far.
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(Role::Assistant, String::new(), cx);
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Each update is treated as the running total for this
                                // completion: replace the previously counted value rather than
                                // adding on top of it.
                                thread.cumulative_token_usage =
                                    thread.cumulative_token_usage.clone() + token_usage.clone()
                                        - current_token_usage.clone();
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.text.push_str(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(Role::Assistant, chunk, cx);
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses are attached to the most recent assistant message;
                                // they are dropped if there is none.
                                if let Some(last_assistant_message) = thread
                                    .messages
                                    .iter()
                                    .rfind(|message| message.role == Role::Assistant)
                                {
                                    if tool_use.name.as_ref() == ScriptingTool::NAME {
                                        thread.scripting_tool_use.request_tool_use(
                                            last_assistant_message.id,
                                            tool_use,
                                            cx,
                                        );
                                    } else {
                                        thread.tool_use.request_tool_use(
                                            last_assistant_message.id,
                                            tool_use,
                                            cx,
                                        );
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();
                    })?;

                    // Give other tasks a chance to run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                cx.emit(ThreadEvent::UsePendingTools);
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Billing errors get dedicated variants; anything else is shown
                            // with its full cause chain, one cause per line.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::DoneStreaming);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage.clone() - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
952
953 pub fn summarize(&mut self, cx: &mut Context<Self>) {
954 let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
955 return;
956 };
957 let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
958 return;
959 };
960
961 if !provider.is_authenticated(cx) {
962 return;
963 }
964
965 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
966 request.messages.push(LanguageModelRequestMessage {
967 role: Role::User,
968 content: vec![
969 "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
970 .into(),
971 ],
972 cache: false,
973 });
974
975 self.pending_summary = cx.spawn(async move |this, cx| {
976 async move {
977 let stream = model.stream_completion_text(request, &cx);
978 let mut messages = stream.await?;
979
980 let mut new_summary = String::new();
981 while let Some(message) = messages.stream.next().await {
982 let text = message?;
983 let mut lines = text.lines();
984 new_summary.extend(lines.next());
985
986 // Stop if the LLM generated multiple lines.
987 if lines.next().is_some() {
988 break;
989 }
990 }
991
992 this.update(cx, |this, cx| {
993 if !new_summary.is_empty() {
994 this.summary = Some(new_summary.into());
995 }
996
997 cx.emit(ThreadEvent::SummaryChanged);
998 })?;
999
1000 anyhow::Ok(())
1001 }
1002 .log_err()
1003 .await
1004 });
1005 }
1006
1007 pub fn use_pending_tools(
1008 &mut self,
1009 cx: &mut Context<Self>,
1010 ) -> impl IntoIterator<Item = PendingToolUse> {
1011 let request = self.to_completion_request(RequestKind::Chat, cx);
1012 let pending_tool_uses = self
1013 .tool_use
1014 .pending_tool_uses()
1015 .into_iter()
1016 .filter(|tool_use| tool_use.status.is_idle())
1017 .cloned()
1018 .collect::<Vec<_>>();
1019
1020 for tool_use in pending_tool_uses.iter() {
1021 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1022 let task = tool.run(
1023 tool_use.input.clone(),
1024 &request.messages,
1025 self.project.clone(),
1026 self.action_log.clone(),
1027 cx,
1028 );
1029
1030 self.insert_tool_output(
1031 tool_use.id.clone(),
1032 tool_use.ui_text.clone().into(),
1033 task,
1034 cx,
1035 );
1036 }
1037 }
1038
1039 let pending_scripting_tool_uses = self
1040 .scripting_tool_use
1041 .pending_tool_uses()
1042 .into_iter()
1043 .filter(|tool_use| tool_use.status.is_idle())
1044 .cloned()
1045 .collect::<Vec<_>>();
1046
1047 for scripting_tool_use in pending_scripting_tool_uses.iter() {
1048 let task = match ScriptingTool::deserialize_input(scripting_tool_use.input.clone()) {
1049 Err(err) => Task::ready(Err(err.into())),
1050 Ok(input) => {
1051 let (script_id, script_task) =
1052 self.scripting_session.update(cx, move |session, cx| {
1053 session.run_script(input.lua_script, cx)
1054 });
1055
1056 let session = self.scripting_session.clone();
1057 cx.spawn(async move |_, cx| {
1058 script_task.await;
1059
1060 let message = session.read_with(cx, |session, _cx| {
1061 // Using a id to get the script output seems impractical.
1062 // Why not just include it in the Task result?
1063 // This is because we'll later report the script state as it runs,
1064 session
1065 .get(script_id)
1066 .output_message_for_llm()
1067 .expect("Script shouldn't still be running")
1068 })?;
1069
1070 Ok(message)
1071 })
1072 }
1073 };
1074
1075 let ui_text: SharedString = scripting_tool_use.name.clone().into();
1076
1077 self.insert_scripting_tool_output(scripting_tool_use.id.clone(), ui_text, task, cx);
1078 }
1079
1080 pending_tool_uses
1081 .into_iter()
1082 .chain(pending_scripting_tool_uses)
1083 }
1084
1085 pub fn insert_tool_output(
1086 &mut self,
1087 tool_use_id: LanguageModelToolUseId,
1088 ui_text: SharedString,
1089 output: Task<Result<String>>,
1090 cx: &mut Context<Self>,
1091 ) {
1092 let insert_output_task = cx.spawn({
1093 let tool_use_id = tool_use_id.clone();
1094 async move |thread, cx| {
1095 let output = output.await;
1096 thread
1097 .update(cx, |thread, cx| {
1098 let pending_tool_use = thread
1099 .tool_use
1100 .insert_tool_output(tool_use_id.clone(), output);
1101
1102 cx.emit(ThreadEvent::ToolFinished {
1103 tool_use_id,
1104 pending_tool_use,
1105 canceled: false,
1106 });
1107 })
1108 .ok();
1109 }
1110 });
1111
1112 self.tool_use
1113 .run_pending_tool(tool_use_id, ui_text, insert_output_task);
1114 }
1115
1116 pub fn insert_scripting_tool_output(
1117 &mut self,
1118 tool_use_id: LanguageModelToolUseId,
1119 ui_text: SharedString,
1120 output: Task<Result<String>>,
1121 cx: &mut Context<Self>,
1122 ) {
1123 let insert_output_task = cx.spawn({
1124 let tool_use_id = tool_use_id.clone();
1125 async move |thread, cx| {
1126 let output = output.await;
1127 thread
1128 .update(cx, |thread, cx| {
1129 let pending_tool_use = thread
1130 .scripting_tool_use
1131 .insert_tool_output(tool_use_id.clone(), output);
1132
1133 cx.emit(ThreadEvent::ToolFinished {
1134 tool_use_id,
1135 pending_tool_use,
1136 canceled: false,
1137 });
1138 })
1139 .ok();
1140 }
1141 });
1142
1143 self.scripting_tool_use
1144 .run_pending_tool(tool_use_id, ui_text, insert_output_task);
1145 }
1146
1147 pub fn attach_tool_results(
1148 &mut self,
1149 updated_context: Vec<ContextSnapshot>,
1150 cx: &mut Context<Self>,
1151 ) {
1152 self.context.extend(
1153 updated_context
1154 .into_iter()
1155 .map(|context| (context.id, context)),
1156 );
1157
1158 // Insert a user message to contain the tool results.
1159 self.insert_user_message(
1160 // TODO: Sending up a user message without any content results in the model sending back
1161 // responses that also don't have any content. We currently don't handle this case well,
1162 // so for now we provide some text to keep the model on track.
1163 "Here are the tool results.",
1164 Vec::new(),
1165 None,
1166 cx,
1167 );
1168 }
1169
1170 /// Cancels the last pending completion, if there are any pending.
1171 ///
1172 /// Returns whether a completion was canceled.
1173 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1174 if self.pending_completions.pop().is_some() {
1175 true
1176 } else {
1177 let mut canceled = false;
1178 for pending_tool_use in self.tool_use.cancel_pending() {
1179 canceled = true;
1180 cx.emit(ThreadEvent::ToolFinished {
1181 tool_use_id: pending_tool_use.id.clone(),
1182 pending_tool_use: Some(pending_tool_use),
1183 canceled: true,
1184 });
1185 }
1186 canceled
1187 }
1188 }
1189
1190 /// Reports feedback about the thread and stores it in our telemetry backend.
1191 pub fn report_feedback(&self, is_positive: bool, cx: &mut Context<Self>) -> Task<Result<()>> {
1192 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1193 let serialized_thread = self.serialize(cx);
1194 let thread_id = self.id().clone();
1195 let client = self.project.read(cx).client();
1196
1197 cx.background_spawn(async move {
1198 let final_project_snapshot = final_project_snapshot.await;
1199 let serialized_thread = serialized_thread.await?;
1200 let thread_data =
1201 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1202
1203 let rating = if is_positive { "positive" } else { "negative" };
1204 telemetry::event!(
1205 "Assistant Thread Rated",
1206 rating,
1207 thread_id,
1208 thread_data,
1209 final_project_snapshot
1210 );
1211 client.telemetry().flush_events();
1212
1213 Ok(())
1214 })
1215 }
1216
1217 /// Create a snapshot of the current project state including git information and unsaved buffers.
1218 fn project_snapshot(
1219 project: Entity<Project>,
1220 cx: &mut Context<Self>,
1221 ) -> Task<Arc<ProjectSnapshot>> {
1222 let worktree_snapshots: Vec<_> = project
1223 .read(cx)
1224 .visible_worktrees(cx)
1225 .map(|worktree| Self::worktree_snapshot(worktree, cx))
1226 .collect();
1227
1228 cx.spawn(async move |_, cx| {
1229 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1230
1231 let mut unsaved_buffers = Vec::new();
1232 cx.update(|app_cx| {
1233 let buffer_store = project.read(app_cx).buffer_store();
1234 for buffer_handle in buffer_store.read(app_cx).buffers() {
1235 let buffer = buffer_handle.read(app_cx);
1236 if buffer.is_dirty() {
1237 if let Some(file) = buffer.file() {
1238 let path = file.path().to_string_lossy().to_string();
1239 unsaved_buffers.push(path);
1240 }
1241 }
1242 }
1243 })
1244 .ok();
1245
1246 Arc::new(ProjectSnapshot {
1247 worktree_snapshots,
1248 unsaved_buffer_paths: unsaved_buffers,
1249 timestamp: Utc::now(),
1250 })
1251 })
1252 }
1253
1254 fn worktree_snapshot(worktree: Entity<project::Worktree>, cx: &App) -> Task<WorktreeSnapshot> {
1255 cx.spawn(async move |cx| {
1256 // Get worktree path and snapshot
1257 let worktree_info = cx.update(|app_cx| {
1258 let worktree = worktree.read(app_cx);
1259 let path = worktree.abs_path().to_string_lossy().to_string();
1260 let snapshot = worktree.snapshot();
1261 (path, snapshot)
1262 });
1263
1264 let Ok((worktree_path, snapshot)) = worktree_info else {
1265 return WorktreeSnapshot {
1266 worktree_path: String::new(),
1267 git_state: None,
1268 };
1269 };
1270
1271 // Extract git information
1272 let git_state = match snapshot.repositories().first() {
1273 None => None,
1274 Some(repo_entry) => {
1275 // Get branch information
1276 let current_branch = repo_entry.branch().map(|branch| branch.name.to_string());
1277
1278 // Get repository info
1279 let repo_result = worktree.read_with(cx, |worktree, _cx| {
1280 if let project::Worktree::Local(local_worktree) = &worktree {
1281 local_worktree.get_local_repo(repo_entry).map(|local_repo| {
1282 let repo = local_repo.repo();
1283 (repo.remote_url("origin"), repo.head_sha(), repo.clone())
1284 })
1285 } else {
1286 None
1287 }
1288 });
1289
1290 match repo_result {
1291 Ok(Some((remote_url, head_sha, repository))) => {
1292 // Get diff asynchronously
1293 let diff = repository
1294 .diff(git::repository::DiffType::HeadToWorktree, cx.clone())
1295 .await
1296 .ok();
1297
1298 Some(GitState {
1299 remote_url,
1300 head_sha,
1301 current_branch,
1302 diff,
1303 })
1304 }
1305 Err(_) | Ok(None) => None,
1306 }
1307 }
1308 };
1309
1310 WorktreeSnapshot {
1311 worktree_path,
1312 git_state,
1313 }
1314 })
1315 }
1316
1317 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1318 let mut markdown = Vec::new();
1319
1320 if let Some(summary) = self.summary() {
1321 writeln!(markdown, "# {summary}\n")?;
1322 };
1323
1324 for message in self.messages() {
1325 writeln!(
1326 markdown,
1327 "## {role}\n",
1328 role = match message.role {
1329 Role::User => "User",
1330 Role::Assistant => "Assistant",
1331 Role::System => "System",
1332 }
1333 )?;
1334 writeln!(markdown, "{}\n", message.text)?;
1335
1336 for tool_use in self.tool_uses_for_message(message.id, cx) {
1337 writeln!(
1338 markdown,
1339 "**Use Tool: {} ({})**",
1340 tool_use.name, tool_use.id
1341 )?;
1342 writeln!(markdown, "```json")?;
1343 writeln!(
1344 markdown,
1345 "{}",
1346 serde_json::to_string_pretty(&tool_use.input)?
1347 )?;
1348 writeln!(markdown, "```")?;
1349 }
1350
1351 for tool_result in self.tool_results_for_message(message.id) {
1352 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1353 if tool_result.is_error {
1354 write!(markdown, " (Error)")?;
1355 }
1356
1357 writeln!(markdown, "**\n")?;
1358 writeln!(markdown, "{}", tool_result.content)?;
1359 }
1360 }
1361
1362 Ok(String::from_utf8_lossy(&markdown).to_string())
1363 }
1364
    /// Returns a reference to the thread's [`ActionLog`] entity.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1368
    /// Returns a reference to the [`Project`] entity held by this thread.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1372
    /// Returns a clone of the thread's cumulative [`TokenUsage`].
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage.clone()
    }
1376}
1377
/// An error carried by [`ThreadEvent::ShowError`].
#[derive(Debug, Clone)]
pub enum ThreadError {
    /// NOTE(review): presumably raised when a completion fails with
    /// `PaymentRequiredError`; confirm at the construction site.
    PaymentRequired,
    /// NOTE(review): presumably raised when a completion fails with
    /// `MaxMonthlySpendReachedError`; confirm at the construction site.
    MaxMonthlySpendReached,
    /// A free-form error described by a `header` and a `message`.
    /// Presumably the header is the title and the message the details.
    Message {
        header: SharedString,
        message: SharedString,
    },
}
1387
/// Events emitted by a [`Thread`] through `cx.emit`.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries a [`ThreadError`] (presumably to be displayed to the user).
    ShowError(ThreadError),
    /// NOTE(review): presumably emitted as completion output streams in; confirm.
    StreamedCompletion,
    /// Text streamed into the message with the given id. Presumably the `String`
    /// is the newly streamed chunk; confirm.
    StreamedAssistantText(MessageId, String),
    /// NOTE(review): presumably emitted once streaming finishes; confirm.
    DoneStreaming,
    /// A message with the given id was added.
    MessageAdded(MessageId),
    /// A message with the given id was edited.
    MessageEdited(MessageId),
    /// A message with the given id was deleted.
    MessageDeleted(MessageId),
    /// The thread's summary changed.
    SummaryChanged,
    /// NOTE(review): presumably asks subscribers to run pending tool uses; confirm.
    UsePendingTools,
    /// A tool use finished. `insert_tool_output` and `insert_scripting_tool_output`
    /// emit it with `canceled: false` once the output is recorded.
    /// `cancel_last_completion` emits it with `canceled: true`.
    ToolFinished {
        // Silences the unused-field lint: no subscriber appears to read this field
        // yet. NOTE(review): drop the allow once it is read.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
        /// Whether the tool was canceled by the user.
        canceled: bool,
    },
    /// NOTE(review): presumably signals that a checkpoint was created or updated; confirm.
    CheckpointChanged,
}
1409
// Lets `Thread` emit `ThreadEvent`s via `cx.emit(...)`, as its methods
// (e.g. `cancel_last_completion`) do.
impl EventEmitter<ThreadEvent> for Thread {}
1411
/// A completion request still in progress. It is tracked in `pending_completions`,
/// and `cancel_last_completion` pops it from there.
struct PendingCompletion {
    /// NOTE(review): presumably used to locate and remove this entry when the
    /// completion finishes; confirm.
    id: usize,
    /// Held only to keep the completion task alive. Popping this entry drops the
    /// task, and a dropped gpui task is canceled.
    _task: Task<()>,
}