1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_scripting::{
5 Script, ScriptEvent, ScriptId, ScriptSession, ScriptTagParser, SCRIPTING_PROMPT,
6};
7use assistant_tool::ToolWorkingSet;
8use chrono::{DateTime, Utc};
9use collections::{BTreeMap, HashMap, HashSet};
10use futures::StreamExt as _;
11use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
12use language_model::{
13 LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
14 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
15 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
16 Role, StopReason,
17};
18use project::Project;
19use serde::{Deserialize, Serialize};
20use util::{post_inc, TryFutureExt as _};
21use uuid::Uuid;
22
23use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
24use crate::thread_store::SavedThread;
25use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
26
/// What a completion request built from a [`Thread`] is for.
///
/// The kind controls what `Thread::to_completion_request` includes in the request.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A regular chat turn: tool results, tool uses, and script sources are included.
    Chat,
    /// Used when summarizing a thread.
    ///
    /// Only message text and context are sent; tool use and script sources are omitted.
    Summarize,
}
33
/// Unique identifier of a [`Thread`].
///
/// Stored as an `Arc<str>`, so cloning an id shares the string instead of copying it.
/// New ids are created by [`ThreadId::new`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct ThreadId(Arc<str>);
36
37impl ThreadId {
38 pub fn new() -> Self {
39 Self(Uuid::new_v4().to_string().into())
40 }
41}
42
43impl std::fmt::Display for ThreadId {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 write!(f, "{}", self.0)
46 }
47}
48
/// Identifier of a [`Message`] within a [`Thread`].
///
/// A thread hands out ids in increasing order as messages are inserted.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
51
52impl MessageId {
53 fn post_inc(&mut self) -> Self {
54 Self(post_inc(&mut self.0))
55 }
56}
57
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Id of this message, unique within its thread.
    pub id: MessageId,
    /// Who authored the message.
    pub role: Role,
    /// Message text.
    ///
    /// May be empty: for example, the assistant message inserted when a completion starts
    /// begins empty and is filled in as text streams in.
    pub text: String,
}
65
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Identifier of this thread.
    id: ThreadId,
    /// Last modification time.
    ///
    /// Bumped via `touch_updated_at` when messages are inserted, edited, or deleted, and after
    /// each streamed completion event.
    updated_at: DateTime<Utc>,
    /// Short title of the thread; `None` until one is set or generated.
    summary: Option<SharedString>,
    /// Summarization task; initially an already-completed task, replaced by each `summarize`.
    pending_summary: Task<Option<()>>,
    /// All messages, in insertion order.
    messages: Vec<Message>,
    /// Id that will be handed out to the next inserted message.
    next_message_id: MessageId,
    /// Context snapshots by id; one snapshot may be referenced by several messages.
    context: BTreeMap<ContextId, ContextSnapshot>,
    /// Ids of the context snapshots attached to each message inserted as a user message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Counter used to hand out ids to [`PendingCompletion`]s.
    completion_count: usize,
    /// Completions that are still streaming; non-empty means the thread is streaming.
    pending_completions: Vec<PendingCompletion>,
    /// Project that tools and scripts operate on.
    project: Entity<Project>,
    /// Tools that may be offered to the model and run for its tool uses.
    tools: Arc<ToolWorkingSet>,
    /// Tool uses requested by the model, and their results.
    tool_use: ToolUseState,
    /// Script started from each assistant message that contained a script tag.
    scripts_by_assistant_message: HashMap<MessageId, ScriptId>,
    /// User messages that were inserted to carry a script's output back to the model.
    script_output_messages: HashSet<MessageId>,
    /// Session that creates and runs the scripts emitted by the model.
    script_session: Entity<ScriptSession>,
    /// Subscription routing `script_session` events to `handle_script_event`.
    /// NOTE(review): held for the thread's lifetime; presumably dropping it would unsubscribe.
    _script_session_subscription: Subscription,
}
86
87impl Thread {
88 pub fn new(
89 project: Entity<Project>,
90 tools: Arc<ToolWorkingSet>,
91 cx: &mut Context<Self>,
92 ) -> Self {
93 let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
94 let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
95
96 Self {
97 id: ThreadId::new(),
98 updated_at: Utc::now(),
99 summary: None,
100 pending_summary: Task::ready(None),
101 messages: Vec::new(),
102 next_message_id: MessageId(0),
103 context: BTreeMap::default(),
104 context_by_message: HashMap::default(),
105 completion_count: 0,
106 pending_completions: Vec::new(),
107 project,
108 tools,
109 tool_use: ToolUseState::new(),
110 scripts_by_assistant_message: HashMap::default(),
111 script_output_messages: HashSet::default(),
112 script_session,
113 _script_session_subscription: script_session_subscription,
114 }
115 }
116
117 pub fn from_saved(
118 id: ThreadId,
119 saved: SavedThread,
120 project: Entity<Project>,
121 tools: Arc<ToolWorkingSet>,
122 cx: &mut Context<Self>,
123 ) -> Self {
124 let next_message_id = MessageId(
125 saved
126 .messages
127 .last()
128 .map(|message| message.id.0 + 1)
129 .unwrap_or(0),
130 );
131 let tool_use = ToolUseState::from_saved_messages(&saved.messages);
132 let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
133 let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
134
135 Self {
136 id,
137 updated_at: saved.updated_at,
138 summary: Some(saved.summary),
139 pending_summary: Task::ready(None),
140 messages: saved
141 .messages
142 .into_iter()
143 .map(|message| Message {
144 id: message.id,
145 role: message.role,
146 text: message.text,
147 })
148 .collect(),
149 next_message_id,
150 context: BTreeMap::default(),
151 context_by_message: HashMap::default(),
152 completion_count: 0,
153 pending_completions: Vec::new(),
154 project,
155 tools,
156 tool_use,
157 scripts_by_assistant_message: HashMap::default(),
158 script_output_messages: HashSet::default(),
159 script_session,
160 _script_session_subscription: script_session_subscription,
161 }
162 }
163
164 pub fn id(&self) -> &ThreadId {
165 &self.id
166 }
167
168 pub fn is_empty(&self) -> bool {
169 self.messages.is_empty()
170 }
171
172 pub fn updated_at(&self) -> DateTime<Utc> {
173 self.updated_at
174 }
175
176 pub fn touch_updated_at(&mut self) {
177 self.updated_at = Utc::now();
178 }
179
180 pub fn summary(&self) -> Option<SharedString> {
181 self.summary.clone()
182 }
183
184 pub fn summary_or_default(&self) -> SharedString {
185 const DEFAULT: SharedString = SharedString::new_static("New Thread");
186 self.summary.clone().unwrap_or(DEFAULT)
187 }
188
189 pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
190 self.summary = Some(summary.into());
191 cx.emit(ThreadEvent::SummaryChanged);
192 }
193
194 pub fn message(&self, id: MessageId) -> Option<&Message> {
195 self.messages.iter().find(|message| message.id == id)
196 }
197
198 pub fn messages(&self) -> impl Iterator<Item = &Message> {
199 self.messages.iter()
200 }
201
202 pub fn is_streaming(&self) -> bool {
203 !self.pending_completions.is_empty()
204 }
205
206 pub fn tools(&self) -> &Arc<ToolWorkingSet> {
207 &self.tools
208 }
209
210 pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
211 let context = self.context_by_message.get(&id)?;
212 Some(
213 context
214 .into_iter()
215 .filter_map(|context_id| self.context.get(&context_id))
216 .cloned()
217 .collect::<Vec<_>>(),
218 )
219 }
220
221 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
222 self.tool_use.pending_tool_uses()
223 }
224
225 /// Returns whether all of the tool uses have finished running.
226 pub fn all_tools_finished(&self) -> bool {
227 // If the only pending tool uses left are the ones with errors, then that means that we've finished running all
228 // of the pending tools.
229 self.pending_tool_uses()
230 .into_iter()
231 .all(|tool_use| tool_use.status.is_error())
232 }
233
234 pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
235 self.tool_use.tool_uses_for_message(id)
236 }
237
238 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
239 self.tool_use.tool_results_for_message(id)
240 }
241
242 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
243 self.tool_use.message_has_tool_results(message_id)
244 }
245
246 pub fn message_has_script_output(&self, message_id: MessageId) -> bool {
247 self.script_output_messages.contains(&message_id)
248 }
249
250 pub fn insert_user_message(
251 &mut self,
252 text: impl Into<String>,
253 context: Vec<ContextSnapshot>,
254 cx: &mut Context<Self>,
255 ) -> MessageId {
256 let message_id = self.insert_message(Role::User, text, cx);
257 let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
258 self.context
259 .extend(context.into_iter().map(|context| (context.id, context)));
260 self.context_by_message.insert(message_id, context_ids);
261 message_id
262 }
263
264 pub fn insert_message(
265 &mut self,
266 role: Role,
267 text: impl Into<String>,
268 cx: &mut Context<Self>,
269 ) -> MessageId {
270 let id = self.next_message_id.post_inc();
271 self.messages.push(Message {
272 id,
273 role,
274 text: text.into(),
275 });
276 self.touch_updated_at();
277 cx.emit(ThreadEvent::MessageAdded(id));
278 id
279 }
280
281 pub fn edit_message(
282 &mut self,
283 id: MessageId,
284 new_role: Role,
285 new_text: String,
286 cx: &mut Context<Self>,
287 ) -> bool {
288 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
289 return false;
290 };
291 message.role = new_role;
292 message.text = new_text;
293 self.touch_updated_at();
294 cx.emit(ThreadEvent::MessageEdited(id));
295 true
296 }
297
298 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
299 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
300 return false;
301 };
302 self.messages.remove(index);
303 self.context_by_message.remove(&id);
304 self.touch_updated_at();
305 cx.emit(ThreadEvent::MessageDeleted(id));
306 true
307 }
308
309 /// Returns the representation of this [`Thread`] in a textual form.
310 ///
311 /// This is the representation we use when attaching a thread as context to another thread.
312 pub fn text(&self) -> String {
313 let mut text = String::new();
314
315 for message in &self.messages {
316 text.push_str(match message.role {
317 language_model::Role::User => "User:",
318 language_model::Role::Assistant => "Assistant:",
319 language_model::Role::System => "System:",
320 });
321 text.push('\n');
322
323 text.push_str(&message.text);
324 text.push('\n');
325 }
326
327 text
328 }
329
330 pub fn script_for_message<'a>(
331 &'a self,
332 message_id: MessageId,
333 cx: &'a App,
334 ) -> Option<&'a Script> {
335 self.scripts_by_assistant_message
336 .get(&message_id)
337 .map(|script_id| self.script_session.read(cx).get(*script_id))
338 }
339
340 fn handle_script_event(
341 &mut self,
342 _script_session: Entity<ScriptSession>,
343 event: &ScriptEvent,
344 cx: &mut Context<Self>,
345 ) {
346 match event {
347 ScriptEvent::Spawned(_) => {}
348 ScriptEvent::Exited(script_id) => {
349 if let Some(output_message) = self
350 .script_session
351 .read(cx)
352 .get(*script_id)
353 .output_message_for_llm()
354 {
355 let message_id = self.insert_user_message(output_message, vec![], cx);
356 self.script_output_messages.insert(message_id);
357 cx.emit(ThreadEvent::ScriptFinished)
358 }
359 }
360 }
361 }
362
363 pub fn send_to_model(
364 &mut self,
365 model: Arc<dyn LanguageModel>,
366 request_kind: RequestKind,
367 use_tools: bool,
368 cx: &mut Context<Self>,
369 ) {
370 let mut request = self.to_completion_request(request_kind, cx);
371
372 if use_tools {
373 request.tools = self
374 .tools()
375 .tools(cx)
376 .into_iter()
377 .map(|tool| LanguageModelRequestTool {
378 name: tool.name(),
379 description: tool.description(),
380 input_schema: tool.input_schema(),
381 })
382 .collect();
383 }
384
385 self.stream_completion(request, model, cx);
386 }
387
388 pub fn to_completion_request(
389 &self,
390 request_kind: RequestKind,
391 cx: &App,
392 ) -> LanguageModelRequest {
393 let mut request = LanguageModelRequest {
394 messages: vec![],
395 tools: Vec::new(),
396 stop: Vec::new(),
397 temperature: None,
398 };
399
400 request.messages.push(LanguageModelRequestMessage {
401 role: Role::System,
402 content: vec![SCRIPTING_PROMPT.to_string().into()],
403 cache: true,
404 });
405
406 let mut referenced_context_ids = HashSet::default();
407
408 for message in &self.messages {
409 if let Some(context_ids) = self.context_by_message.get(&message.id) {
410 referenced_context_ids.extend(context_ids);
411 }
412
413 let mut request_message = LanguageModelRequestMessage {
414 role: message.role,
415 content: Vec::new(),
416 cache: false,
417 };
418
419 match request_kind {
420 RequestKind::Chat => {
421 self.tool_use
422 .attach_tool_results(message.id, &mut request_message);
423 }
424 RequestKind::Summarize => {
425 // We don't care about tool use during summarization.
426 }
427 }
428
429 if !message.text.is_empty() {
430 request_message
431 .content
432 .push(MessageContent::Text(message.text.clone()));
433 }
434
435 match request_kind {
436 RequestKind::Chat => {
437 self.tool_use
438 .attach_tool_uses(message.id, &mut request_message);
439
440 if matches!(message.role, Role::Assistant) {
441 if let Some(script_id) = self.scripts_by_assistant_message.get(&message.id)
442 {
443 let script = self.script_session.read(cx).get(*script_id);
444
445 request_message.content.push(script.source_tag().into());
446 }
447 }
448 }
449 RequestKind::Summarize => {
450 // We don't care about tool use during summarization.
451 }
452 };
453
454 request.messages.push(request_message);
455 }
456
457 if !referenced_context_ids.is_empty() {
458 let mut context_message = LanguageModelRequestMessage {
459 role: Role::User,
460 content: Vec::new(),
461 cache: false,
462 };
463
464 let referenced_context = referenced_context_ids
465 .into_iter()
466 .filter_map(|context_id| self.context.get(context_id))
467 .cloned();
468 attach_context_to_message(&mut context_message, referenced_context);
469
470 request.messages.push(context_message);
471 }
472
473 request
474 }
475
476 pub fn stream_completion(
477 &mut self,
478 request: LanguageModelRequest,
479 model: Arc<dyn LanguageModel>,
480 cx: &mut Context<Self>,
481 ) {
482 let pending_completion_id = post_inc(&mut self.completion_count);
483
484 let task = cx.spawn(|thread, mut cx| async move {
485 let stream = model.stream_completion(request, &cx);
486 let stream_completion = async {
487 let mut events = stream.await?;
488 let mut stop_reason = StopReason::EndTurn;
489 let mut script_tag_parser = ScriptTagParser::new();
490 let mut script_id = None;
491
492 while let Some(event) = events.next().await {
493 let event = event?;
494
495 thread.update(&mut cx, |thread, cx| {
496 match event {
497 LanguageModelCompletionEvent::StartMessage { .. } => {
498 thread.insert_message(Role::Assistant, String::new(), cx);
499 }
500 LanguageModelCompletionEvent::Stop(reason) => {
501 stop_reason = reason;
502 }
503 LanguageModelCompletionEvent::Text(chunk) => {
504 if let Some(last_message) = thread.messages.last_mut() {
505 let chunk = script_tag_parser.parse_chunk(&chunk);
506
507 let message_id = if last_message.role == Role::Assistant {
508 last_message.text.push_str(&chunk.content);
509 cx.emit(ThreadEvent::StreamedAssistantText(
510 last_message.id,
511 chunk.content,
512 ));
513 last_message.id
514 } else {
515 // If we won't have an Assistant message yet, assume this chunk marks the beginning
516 // of a new Assistant response.
517 //
518 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
519 // will result in duplicating the text of the chunk in the rendered Markdown.
520 thread.insert_message(Role::Assistant, chunk.content, cx)
521 };
522
523 if script_id.is_none() && script_tag_parser.found_script() {
524 let id = thread
525 .script_session
526 .update(cx, |session, _cx| session.new_script());
527 thread.scripts_by_assistant_message.insert(message_id, id);
528
529 script_id = Some(id);
530 }
531
532 if let (Some(script_source), Some(script_id)) =
533 (chunk.script_source, script_id)
534 {
535 // TODO: move buffer to script and run as it streams
536 thread
537 .script_session
538 .update(cx, |this, cx| {
539 this.run_script(script_id, script_source, cx)
540 })
541 .detach_and_log_err(cx);
542 }
543 }
544 }
545 LanguageModelCompletionEvent::ToolUse(tool_use) => {
546 if let Some(last_assistant_message) = thread
547 .messages
548 .iter()
549 .rfind(|message| message.role == Role::Assistant)
550 {
551 thread
552 .tool_use
553 .request_tool_use(last_assistant_message.id, tool_use);
554 }
555 }
556 }
557
558 thread.touch_updated_at();
559 cx.emit(ThreadEvent::StreamedCompletion);
560 cx.notify();
561 })?;
562
563 smol::future::yield_now().await;
564 }
565
566 thread.update(&mut cx, |thread, cx| {
567 thread
568 .pending_completions
569 .retain(|completion| completion.id != pending_completion_id);
570
571 if thread.summary.is_none() && thread.messages.len() >= 2 {
572 thread.summarize(cx);
573 }
574 })?;
575
576 anyhow::Ok(stop_reason)
577 };
578
579 let result = stream_completion.await;
580
581 thread
582 .update(&mut cx, |thread, cx| match result.as_ref() {
583 Ok(stop_reason) => match stop_reason {
584 StopReason::ToolUse => {
585 cx.emit(ThreadEvent::UsePendingTools);
586 }
587 StopReason::EndTurn => {}
588 StopReason::MaxTokens => {}
589 },
590 Err(error) => {
591 if error.is::<PaymentRequiredError>() {
592 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
593 } else if error.is::<MaxMonthlySpendReachedError>() {
594 cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
595 } else {
596 let error_message = error
597 .chain()
598 .map(|err| err.to_string())
599 .collect::<Vec<_>>()
600 .join("\n");
601 cx.emit(ThreadEvent::ShowError(ThreadError::Message(
602 SharedString::from(error_message.clone()),
603 )));
604 }
605
606 thread.cancel_last_completion();
607 }
608 })
609 .ok();
610 });
611
612 self.pending_completions.push(PendingCompletion {
613 id: pending_completion_id,
614 _task: task,
615 });
616 }
617
618 pub fn summarize(&mut self, cx: &mut Context<Self>) {
619 let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
620 return;
621 };
622 let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
623 return;
624 };
625
626 if !provider.is_authenticated(cx) {
627 return;
628 }
629
630 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
631 request.messages.push(LanguageModelRequestMessage {
632 role: Role::User,
633 content: vec![
634 "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
635 .into(),
636 ],
637 cache: false,
638 });
639
640 self.pending_summary = cx.spawn(|this, mut cx| {
641 async move {
642 let stream = model.stream_completion_text(request, &cx);
643 let mut messages = stream.await?;
644
645 let mut new_summary = String::new();
646 while let Some(message) = messages.stream.next().await {
647 let text = message?;
648 let mut lines = text.lines();
649 new_summary.extend(lines.next());
650
651 // Stop if the LLM generated multiple lines.
652 if lines.next().is_some() {
653 break;
654 }
655 }
656
657 this.update(&mut cx, |this, cx| {
658 if !new_summary.is_empty() {
659 this.summary = Some(new_summary.into());
660 }
661
662 cx.emit(ThreadEvent::SummaryChanged);
663 })?;
664
665 anyhow::Ok(())
666 }
667 .log_err()
668 });
669 }
670
671 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
672 let pending_tool_uses = self
673 .pending_tool_uses()
674 .into_iter()
675 .filter(|tool_use| tool_use.status.is_idle())
676 .cloned()
677 .collect::<Vec<_>>();
678
679 for tool_use in pending_tool_uses {
680 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
681 let task = tool.run(tool_use.input, self.project.clone(), cx);
682
683 self.insert_tool_output(tool_use.id.clone(), task, cx);
684 }
685 }
686 }
687
688 pub fn insert_tool_output(
689 &mut self,
690 tool_use_id: LanguageModelToolUseId,
691 output: Task<Result<String>>,
692 cx: &mut Context<Self>,
693 ) {
694 let insert_output_task = cx.spawn(|thread, mut cx| {
695 let tool_use_id = tool_use_id.clone();
696 async move {
697 let output = output.await;
698 thread
699 .update(&mut cx, |thread, cx| {
700 thread
701 .tool_use
702 .insert_tool_output(tool_use_id.clone(), output);
703
704 cx.emit(ThreadEvent::ToolFinished { tool_use_id });
705 })
706 .ok();
707 }
708 });
709
710 self.tool_use
711 .run_pending_tool(tool_use_id, insert_output_task);
712 }
713
714 pub fn send_tool_results_to_model(
715 &mut self,
716 model: Arc<dyn LanguageModel>,
717 cx: &mut Context<Self>,
718 ) {
719 // Insert a user message to contain the tool results.
720 self.insert_user_message(
721 // TODO: Sending up a user message without any content results in the model sending back
722 // responses that also don't have any content. We currently don't handle this case well,
723 // so for now we provide some text to keep the model on track.
724 "Here are the tool results.",
725 Vec::new(),
726 cx,
727 );
728 self.send_to_model(model, RequestKind::Chat, true, cx);
729 }
730
731 /// Cancels the last pending completion, if there are any pending.
732 ///
733 /// Returns whether a completion was canceled.
734 pub fn cancel_last_completion(&mut self) -> bool {
735 if let Some(_last_completion) = self.pending_completions.pop() {
736 true
737 } else {
738 false
739 }
740 }
741}
742
/// An error from a completion request, to be shown to the user via [`ThreadEvent::ShowError`].
#[derive(Debug, Clone)]
pub enum ThreadError {
    /// The completion failed with a `PaymentRequiredError`.
    PaymentRequired,
    /// The completion failed with a `MaxMonthlySpendReachedError`.
    MaxMonthlySpendReached,
    /// Any other failure; holds the error's cause chain joined with newlines.
    Message(SharedString),
}
749
/// Events emitted by a [`Thread`].
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// A completion failed; the error should be surfaced to the user.
    ShowError(ThreadError),
    /// A streamed completion event was processed.
    StreamedCompletion,
    /// Text was appended to an existing assistant message.
    ///
    /// Carries the message id and the appended chunk.
    StreamedAssistantText(MessageId, String),
    /// A message was inserted.
    MessageAdded(MessageId),
    /// A message's role and/or text was replaced.
    MessageEdited(MessageId),
    /// A message was removed.
    MessageDeleted(MessageId),
    /// The summary was set, or a summarization request completed. The summary text itself may
    /// be unchanged.
    SummaryChanged,
    /// The model stopped to request tool use; pending tools should be run.
    UsePendingTools,
    /// The output of a tool use has been recorded.
    ToolFinished {
        // This field is never read in this file; the allow silences the dead-code lint.
        // NOTE(review): presumably kept for consumers / Debug output — confirm.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
    },
    /// A script exited and its output was inserted as a user message.
    ScriptFinished,
}
766
// Enables the `cx.emit(ThreadEvent::…)` calls made by `Thread`, and lets other entities
// subscribe to a thread's events (as `Thread` itself does with `cx.subscribe` for its
// `ScriptSession`).
impl EventEmitter<ThreadEvent> for Thread {}
768
/// A completion request that is still streaming into a [`Thread`].
struct PendingCompletion {
    /// Id taken from the thread's `completion_count` when the completion started.
    id: usize,
    /// Task driving the completion; removing this entry from `pending_completions` drops it,
    /// which is how a completion is canceled.
    _task: Task<()>,
}