thread.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use chrono::{DateTime, Utc};
  6use collections::{BTreeMap, HashMap, HashSet};
  7use futures::StreamExt as _;
  8use gpui::{App, Context, Entity, EventEmitter, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
 11    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
 12    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
 13    Role, StopReason,
 14};
 15use project::Project;
 16use serde::{Deserialize, Serialize};
 17use util::{post_inc, TryFutureExt as _};
 18use uuid::Uuid;
 19
 20use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
 21use crate::thread_store::SavedThread;
 22use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
 23
 24#[derive(Debug, Clone, Copy)]
 25pub enum RequestKind {
 26    Chat,
 27    /// Used when summarizing a thread.
 28    Summarize,
 29}
 30
 31#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
 32pub struct ThreadId(Arc<str>);
 33
 34impl ThreadId {
 35    pub fn new() -> Self {
 36        Self(Uuid::new_v4().to_string().into())
 37    }
 38}
 39
 40impl std::fmt::Display for ThreadId {
 41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 42        write!(f, "{}", self.0)
 43    }
 44}
 45
 46#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 47pub struct MessageId(pub(crate) usize);
 48
 49impl MessageId {
 50    fn post_inc(&mut self) -> Self {
 51        Self(post_inc(&mut self.0))
 52    }
 53}
 54
 55/// A message in a [`Thread`].
 56#[derive(Debug, Clone)]
 57pub struct Message {
 58    pub id: MessageId,
 59    pub role: Role,
 60    pub text: String,
 61}
 62
 63/// A thread of conversation with the LLM.
 64pub struct Thread {
 65    id: ThreadId,
 66    updated_at: DateTime<Utc>,
 67    summary: Option<SharedString>,
 68    pending_summary: Task<Option<()>>,
 69    messages: Vec<Message>,
 70    next_message_id: MessageId,
 71    context: BTreeMap<ContextId, ContextSnapshot>,
 72    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 73    completion_count: usize,
 74    pending_completions: Vec<PendingCompletion>,
 75    project: Entity<Project>,
 76    tools: Arc<ToolWorkingSet>,
 77    tool_use: ToolUseState,
 78}
 79
 80impl Thread {
 81    pub fn new(
 82        project: Entity<Project>,
 83        tools: Arc<ToolWorkingSet>,
 84        _cx: &mut Context<Self>,
 85    ) -> Self {
 86        Self {
 87            id: ThreadId::new(),
 88            updated_at: Utc::now(),
 89            summary: None,
 90            pending_summary: Task::ready(None),
 91            messages: Vec::new(),
 92            next_message_id: MessageId(0),
 93            context: BTreeMap::default(),
 94            context_by_message: HashMap::default(),
 95            completion_count: 0,
 96            pending_completions: Vec::new(),
 97            project,
 98            tools,
 99            tool_use: ToolUseState::new(),
100        }
101    }
102
103    pub fn from_saved(
104        id: ThreadId,
105        saved: SavedThread,
106        project: Entity<Project>,
107        tools: Arc<ToolWorkingSet>,
108        _cx: &mut Context<Self>,
109    ) -> Self {
110        let next_message_id = MessageId(
111            saved
112                .messages
113                .last()
114                .map(|message| message.id.0 + 1)
115                .unwrap_or(0),
116        );
117        let tool_use = ToolUseState::from_saved_messages(&saved.messages);
118
119        Self {
120            id,
121            updated_at: saved.updated_at,
122            summary: Some(saved.summary),
123            pending_summary: Task::ready(None),
124            messages: saved
125                .messages
126                .into_iter()
127                .map(|message| Message {
128                    id: message.id,
129                    role: message.role,
130                    text: message.text,
131                })
132                .collect(),
133            next_message_id,
134            context: BTreeMap::default(),
135            context_by_message: HashMap::default(),
136            completion_count: 0,
137            pending_completions: Vec::new(),
138            project,
139            tools,
140            tool_use,
141        }
142    }
143
144    pub fn id(&self) -> &ThreadId {
145        &self.id
146    }
147
148    pub fn is_empty(&self) -> bool {
149        self.messages.is_empty()
150    }
151
152    pub fn updated_at(&self) -> DateTime<Utc> {
153        self.updated_at
154    }
155
156    pub fn touch_updated_at(&mut self) {
157        self.updated_at = Utc::now();
158    }
159
160    pub fn summary(&self) -> Option<SharedString> {
161        self.summary.clone()
162    }
163
164    pub fn summary_or_default(&self) -> SharedString {
165        const DEFAULT: SharedString = SharedString::new_static("New Thread");
166        self.summary.clone().unwrap_or(DEFAULT)
167    }
168
169    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
170        self.summary = Some(summary.into());
171        cx.emit(ThreadEvent::SummaryChanged);
172    }
173
174    pub fn message(&self, id: MessageId) -> Option<&Message> {
175        self.messages.iter().find(|message| message.id == id)
176    }
177
178    pub fn messages(&self) -> impl Iterator<Item = &Message> {
179        self.messages.iter()
180    }
181
182    pub fn is_streaming(&self) -> bool {
183        !self.pending_completions.is_empty()
184    }
185
186    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
187        &self.tools
188    }
189
190    pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
191        let context = self.context_by_message.get(&id)?;
192        Some(
193            context
194                .into_iter()
195                .filter_map(|context_id| self.context.get(&context_id))
196                .cloned()
197                .collect::<Vec<_>>(),
198        )
199    }
200
201    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
202        self.tool_use.pending_tool_uses()
203    }
204
205    pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
206        self.tool_use.tool_uses_for_message(id)
207    }
208
209    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
210        self.tool_use.tool_results_for_message(id)
211    }
212
213    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
214        self.tool_use.message_has_tool_results(message_id)
215    }
216
217    pub fn insert_user_message(
218        &mut self,
219        text: impl Into<String>,
220        context: Vec<ContextSnapshot>,
221        cx: &mut Context<Self>,
222    ) {
223        let message_id = self.insert_message(Role::User, text, cx);
224        let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
225        self.context
226            .extend(context.into_iter().map(|context| (context.id, context)));
227        self.context_by_message.insert(message_id, context_ids);
228    }
229
230    pub fn insert_message(
231        &mut self,
232        role: Role,
233        text: impl Into<String>,
234        cx: &mut Context<Self>,
235    ) -> MessageId {
236        let id = self.next_message_id.post_inc();
237        self.messages.push(Message {
238            id,
239            role,
240            text: text.into(),
241        });
242        self.touch_updated_at();
243        cx.emit(ThreadEvent::MessageAdded(id));
244        id
245    }
246
247    pub fn edit_message(
248        &mut self,
249        id: MessageId,
250        new_role: Role,
251        new_text: String,
252        cx: &mut Context<Self>,
253    ) -> bool {
254        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
255            return false;
256        };
257        message.role = new_role;
258        message.text = new_text;
259        self.touch_updated_at();
260        cx.emit(ThreadEvent::MessageEdited(id));
261        true
262    }
263
264    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
265        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
266            return false;
267        };
268        self.messages.remove(index);
269        self.context_by_message.remove(&id);
270        self.touch_updated_at();
271        cx.emit(ThreadEvent::MessageDeleted(id));
272        true
273    }
274
275    /// Returns the representation of this [`Thread`] in a textual form.
276    ///
277    /// This is the representation we use when attaching a thread as context to another thread.
278    pub fn text(&self) -> String {
279        let mut text = String::new();
280
281        for message in &self.messages {
282            text.push_str(match message.role {
283                language_model::Role::User => "User:",
284                language_model::Role::Assistant => "Assistant:",
285                language_model::Role::System => "System:",
286            });
287            text.push('\n');
288
289            text.push_str(&message.text);
290            text.push('\n');
291        }
292
293        text
294    }
295
296    pub fn send_to_model(
297        &mut self,
298        model: Arc<dyn LanguageModel>,
299        request_kind: RequestKind,
300        use_tools: bool,
301        cx: &mut Context<Self>,
302    ) {
303        let mut request = self.to_completion_request(request_kind, cx);
304
305        if use_tools {
306            request.tools = self
307                .tools()
308                .tools(cx)
309                .into_iter()
310                .map(|tool| LanguageModelRequestTool {
311                    name: tool.name(),
312                    description: tool.description(),
313                    input_schema: tool.input_schema(),
314                })
315                .collect();
316        }
317
318        self.stream_completion(request, model, cx);
319    }
320
321    pub fn to_completion_request(
322        &self,
323        request_kind: RequestKind,
324        _cx: &App,
325    ) -> LanguageModelRequest {
326        let mut request = LanguageModelRequest {
327            messages: vec![],
328            tools: Vec::new(),
329            stop: Vec::new(),
330            temperature: None,
331        };
332
333        let mut referenced_context_ids = HashSet::default();
334
335        for message in &self.messages {
336            if let Some(context_ids) = self.context_by_message.get(&message.id) {
337                referenced_context_ids.extend(context_ids);
338            }
339
340            let mut request_message = LanguageModelRequestMessage {
341                role: message.role,
342                content: Vec::new(),
343                cache: false,
344            };
345            match request_kind {
346                RequestKind::Chat => {
347                    self.tool_use
348                        .attach_tool_results(message.id, &mut request_message);
349                }
350                RequestKind::Summarize => {
351                    // We don't care about tool use during summarization.
352                }
353            }
354
355            if !message.text.is_empty() {
356                request_message
357                    .content
358                    .push(MessageContent::Text(message.text.clone()));
359            }
360
361            match request_kind {
362                RequestKind::Chat => {
363                    self.tool_use
364                        .attach_tool_uses(message.id, &mut request_message);
365                }
366                RequestKind::Summarize => {
367                    // We don't care about tool use during summarization.
368                }
369            }
370
371            request.messages.push(request_message);
372        }
373
374        if !referenced_context_ids.is_empty() {
375            let mut context_message = LanguageModelRequestMessage {
376                role: Role::User,
377                content: Vec::new(),
378                cache: false,
379            };
380
381            let referenced_context = referenced_context_ids
382                .into_iter()
383                .filter_map(|context_id| self.context.get(context_id))
384                .cloned();
385            attach_context_to_message(&mut context_message, referenced_context);
386
387            request.messages.push(context_message);
388        }
389
390        request
391    }
392
393    pub fn stream_completion(
394        &mut self,
395        request: LanguageModelRequest,
396        model: Arc<dyn LanguageModel>,
397        cx: &mut Context<Self>,
398    ) {
399        let pending_completion_id = post_inc(&mut self.completion_count);
400
401        let task = cx.spawn(|thread, mut cx| async move {
402            let stream = model.stream_completion(request, &cx);
403            let stream_completion = async {
404                let mut events = stream.await?;
405                let mut stop_reason = StopReason::EndTurn;
406
407                while let Some(event) = events.next().await {
408                    let event = event?;
409
410                    thread.update(&mut cx, |thread, cx| {
411                        match event {
412                            LanguageModelCompletionEvent::StartMessage { .. } => {
413                                thread.insert_message(Role::Assistant, String::new(), cx);
414                            }
415                            LanguageModelCompletionEvent::Stop(reason) => {
416                                stop_reason = reason;
417                            }
418                            LanguageModelCompletionEvent::Text(chunk) => {
419                                if let Some(last_message) = thread.messages.last_mut() {
420                                    if last_message.role == Role::Assistant {
421                                        last_message.text.push_str(&chunk);
422                                        cx.emit(ThreadEvent::StreamedAssistantText(
423                                            last_message.id,
424                                            chunk,
425                                        ));
426                                    } else {
427                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
428                                        // of a new Assistant response.
429                                        //
430                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
431                                        // will result in duplicating the text of the chunk in the rendered Markdown.
432                                        thread.insert_message(Role::Assistant, chunk, cx);
433                                    }
434                                }
435                            }
436                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
437                                if let Some(last_assistant_message) = thread
438                                    .messages
439                                    .iter()
440                                    .rfind(|message| message.role == Role::Assistant)
441                                {
442                                    thread
443                                        .tool_use
444                                        .request_tool_use(last_assistant_message.id, tool_use);
445                                }
446                            }
447                        }
448
449                        thread.touch_updated_at();
450                        cx.emit(ThreadEvent::StreamedCompletion);
451                        cx.notify();
452                    })?;
453
454                    smol::future::yield_now().await;
455                }
456
457                thread.update(&mut cx, |thread, cx| {
458                    thread
459                        .pending_completions
460                        .retain(|completion| completion.id != pending_completion_id);
461
462                    if thread.summary.is_none() && thread.messages.len() >= 2 {
463                        thread.summarize(cx);
464                    }
465                })?;
466
467                anyhow::Ok(stop_reason)
468            };
469
470            let result = stream_completion.await;
471
472            thread
473                .update(&mut cx, |thread, cx| match result.as_ref() {
474                    Ok(stop_reason) => match stop_reason {
475                        StopReason::ToolUse => {
476                            cx.emit(ThreadEvent::UsePendingTools);
477                        }
478                        StopReason::EndTurn => {}
479                        StopReason::MaxTokens => {}
480                    },
481                    Err(error) => {
482                        if error.is::<PaymentRequiredError>() {
483                            cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
484                        } else if error.is::<MaxMonthlySpendReachedError>() {
485                            cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
486                        } else {
487                            let error_message = error
488                                .chain()
489                                .map(|err| err.to_string())
490                                .collect::<Vec<_>>()
491                                .join("\n");
492                            cx.emit(ThreadEvent::ShowError(ThreadError::Message(
493                                SharedString::from(error_message.clone()),
494                            )));
495                        }
496
497                        thread.cancel_last_completion();
498                    }
499                })
500                .ok();
501        });
502
503        self.pending_completions.push(PendingCompletion {
504            id: pending_completion_id,
505            _task: task,
506        });
507    }
508
509    pub fn summarize(&mut self, cx: &mut Context<Self>) {
510        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
511            return;
512        };
513        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
514            return;
515        };
516
517        if !provider.is_authenticated(cx) {
518            return;
519        }
520
521        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
522        request.messages.push(LanguageModelRequestMessage {
523            role: Role::User,
524            content: vec![
525                "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
526                    .into(),
527            ],
528            cache: false,
529        });
530
531        self.pending_summary = cx.spawn(|this, mut cx| {
532            async move {
533                let stream = model.stream_completion_text(request, &cx);
534                let mut messages = stream.await?;
535
536                let mut new_summary = String::new();
537                while let Some(message) = messages.stream.next().await {
538                    let text = message?;
539                    let mut lines = text.lines();
540                    new_summary.extend(lines.next());
541
542                    // Stop if the LLM generated multiple lines.
543                    if lines.next().is_some() {
544                        break;
545                    }
546                }
547
548                this.update(&mut cx, |this, cx| {
549                    if !new_summary.is_empty() {
550                        this.summary = Some(new_summary.into());
551                    }
552
553                    cx.emit(ThreadEvent::SummaryChanged);
554                })?;
555
556                anyhow::Ok(())
557            }
558            .log_err()
559        });
560    }
561
562    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
563        let pending_tool_uses = self
564            .pending_tool_uses()
565            .into_iter()
566            .filter(|tool_use| tool_use.status.is_idle())
567            .cloned()
568            .collect::<Vec<_>>();
569
570        for tool_use in pending_tool_uses {
571            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
572                let task = tool.run(tool_use.input, self.project.clone(), cx);
573
574                self.insert_tool_output(tool_use.id.clone(), task, cx);
575            }
576        }
577    }
578
579    pub fn insert_tool_output(
580        &mut self,
581        tool_use_id: LanguageModelToolUseId,
582        output: Task<Result<String>>,
583        cx: &mut Context<Self>,
584    ) {
585        let insert_output_task = cx.spawn(|thread, mut cx| {
586            let tool_use_id = tool_use_id.clone();
587            async move {
588                let output = output.await;
589                thread
590                    .update(&mut cx, |thread, cx| {
591                        thread
592                            .tool_use
593                            .insert_tool_output(tool_use_id.clone(), output);
594
595                        cx.emit(ThreadEvent::ToolFinished { tool_use_id });
596                    })
597                    .ok();
598            }
599        });
600
601        self.tool_use
602            .run_pending_tool(tool_use_id, insert_output_task);
603    }
604
605    pub fn send_tool_results_to_model(
606        &mut self,
607        model: Arc<dyn LanguageModel>,
608        cx: &mut Context<Self>,
609    ) {
610        // Insert a user message to contain the tool results.
611        self.insert_user_message(
612            // TODO: Sending up a user message without any content results in the model sending back
613            // responses that also don't have any content. We currently don't handle this case well,
614            // so for now we provide some text to keep the model on track.
615            "Here are the tool results.",
616            Vec::new(),
617            cx,
618        );
619        self.send_to_model(model, RequestKind::Chat, true, cx);
620    }
621
622    /// Cancels the last pending completion, if there are any pending.
623    ///
624    /// Returns whether a completion was canceled.
625    pub fn cancel_last_completion(&mut self) -> bool {
626        if let Some(_last_completion) = self.pending_completions.pop() {
627            true
628        } else {
629            false
630        }
631    }
632}
633
634#[derive(Debug, Clone)]
635pub enum ThreadError {
636    PaymentRequired,
637    MaxMonthlySpendReached,
638    Message(SharedString),
639}
640
641#[derive(Debug, Clone)]
642pub enum ThreadEvent {
643    ShowError(ThreadError),
644    StreamedCompletion,
645    StreamedAssistantText(MessageId, String),
646    MessageAdded(MessageId),
647    MessageEdited(MessageId),
648    MessageDeleted(MessageId),
649    SummaryChanged,
650    UsePendingTools,
651    ToolFinished {
652        #[allow(unused)]
653        tool_use_id: LanguageModelToolUseId,
654    },
655}
656
657impl EventEmitter<ThreadEvent> for Thread {}
658
659struct PendingCompletion {
660    id: usize,
661    _task: Task<()>,
662}