// thread.rs

 1use std::sync::Arc;
 2
 3use futures::StreamExt as _;
 4use gpui::{EventEmitter, ModelContext, Task};
 5use language_model::{
 6    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, Role, StopReason,
 7};
 8use util::ResultExt as _;
 9
10/// A message in a [`Thread`].
11pub struct Message {
12    pub role: Role,
13    pub text: String,
14}
15
16/// A thread of conversation with the LLM.
17pub struct Thread {
18    pub messages: Vec<Message>,
19    pub pending_completion_tasks: Vec<Task<()>>,
20}
21
22impl Thread {
23    pub fn new(_cx: &mut ModelContext<Self>) -> Self {
24        Self {
25            messages: Vec::new(),
26            pending_completion_tasks: Vec::new(),
27        }
28    }
29
30    pub fn stream_completion(
31        &mut self,
32        request: LanguageModelRequest,
33        model: Arc<dyn LanguageModel>,
34        cx: &mut ModelContext<Self>,
35    ) {
36        let task = cx.spawn(|this, mut cx| async move {
37            let stream = model.stream_completion(request, &cx);
38            let stream_completion = async {
39                let mut events = stream.await?;
40                let mut stop_reason = StopReason::EndTurn;
41
42                while let Some(event) = events.next().await {
43                    let event = event?;
44
45                    this.update(&mut cx, |thread, cx| {
46                        match event {
47                            LanguageModelCompletionEvent::StartMessage { .. } => {
48                                thread.messages.push(Message {
49                                    role: Role::Assistant,
50                                    text: String::new(),
51                                });
52                            }
53                            LanguageModelCompletionEvent::Stop(reason) => {
54                                stop_reason = reason;
55                            }
56                            LanguageModelCompletionEvent::Text(chunk) => {
57                                if let Some(last_message) = thread.messages.last_mut() {
58                                    if last_message.role == Role::Assistant {
59                                        last_message.text.push_str(&chunk);
60                                    }
61                                }
62                            }
63                            LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
64                        }
65
66                        cx.emit(ThreadEvent::StreamedCompletion);
67                        cx.notify();
68                    })?;
69
70                    smol::future::yield_now().await;
71                }
72
73                anyhow::Ok(stop_reason)
74            };
75
76            let result = stream_completion.await;
77            let _ = result.log_err();
78        });
79
80        self.pending_completion_tasks.push(task);
81    }
82}
83
/// Events emitted by a [`Thread`].
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Emitted once for every completion event applied to the thread while a
    /// completion is being streamed.
    StreamedCompletion,
}
88
89impl EventEmitter<ThreadEvent> for Thread {}