1use std::sync::Arc;
2
3use futures::StreamExt as _;
4use gpui::{EventEmitter, ModelContext, Task};
5use language_model::{
6 LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, Role, StopReason,
7};
8use util::ResultExt as _;
9
10/// A message in a [`Thread`].
11pub struct Message {
12 pub role: Role,
13 pub text: String,
14}
15
/// A thread of conversation with the LLM.
pub struct Thread {
    /// The messages in this thread, in the order they were added.
    pub messages: Vec<Message>,
    /// Handles to the completion tasks spawned by `Thread::stream_completion`.
    ///
    /// NOTE(review): presumably kept here so the tasks are not dropped (and
    /// cancelled) while the thread is alive — confirm against gpui `Task`
    /// semantics. Nothing visible in this file removes finished tasks, so
    /// this list only grows; confirm whether that is intended.
    pub pending_completion_tasks: Vec<Task<()>>,
}
21
22impl Thread {
23 pub fn new(_cx: &mut ModelContext<Self>) -> Self {
24 Self {
25 messages: Vec::new(),
26 pending_completion_tasks: Vec::new(),
27 }
28 }
29
30 pub fn stream_completion(
31 &mut self,
32 request: LanguageModelRequest,
33 model: Arc<dyn LanguageModel>,
34 cx: &mut ModelContext<Self>,
35 ) {
36 let task = cx.spawn(|this, mut cx| async move {
37 let stream = model.stream_completion(request, &cx);
38 let stream_completion = async {
39 let mut events = stream.await?;
40 let mut stop_reason = StopReason::EndTurn;
41
42 while let Some(event) = events.next().await {
43 let event = event?;
44
45 this.update(&mut cx, |thread, cx| {
46 match event {
47 LanguageModelCompletionEvent::StartMessage { .. } => {
48 thread.messages.push(Message {
49 role: Role::Assistant,
50 text: String::new(),
51 });
52 }
53 LanguageModelCompletionEvent::Stop(reason) => {
54 stop_reason = reason;
55 }
56 LanguageModelCompletionEvent::Text(chunk) => {
57 if let Some(last_message) = thread.messages.last_mut() {
58 if last_message.role == Role::Assistant {
59 last_message.text.push_str(&chunk);
60 }
61 }
62 }
63 LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
64 }
65
66 cx.emit(ThreadEvent::StreamedCompletion);
67 cx.notify();
68 })?;
69
70 smol::future::yield_now().await;
71 }
72
73 anyhow::Ok(stop_reason)
74 };
75
76 let result = stream_completion.await;
77 let _ = result.log_err();
78 });
79
80 self.pending_completion_tasks.push(task);
81 }
82}
83
/// Events emitted by a [`Thread`].
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Emitted by `Thread::stream_completion` after each event from the
    /// model's completion stream has been processed.
    StreamedCompletion,
}

/// Lets a [`Thread`] emit [`ThreadEvent`]s via `cx.emit`.
impl EventEmitter<ThreadEvent> for Thread {}