use std::sync::Arc;

use futures::StreamExt as _;
use gpui::{AppContext, EventEmitter, ModelContext, Task};
use language_model::{
    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
    MessageContent, Role, StopReason,
};
use util::{post_inc, ResultExt as _};

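/// The kind of request to send to the language model.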
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    Chat,
}

/// A message in a [`Thread`].
pub struct Message {
    pub role: Role,
    pub text: String,
}

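/// A completion that is still streaming in. Holding the [`Task`] keeps the
/// streaming work alive; dropping it cancels the completion.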
struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}

/// A thread of conversation with the LLM.
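///
/// # Example
///
/// A minimal usage sketch (assumed caller-side code, not from this file):
/// create a thread, queue a user message, and stream a response from some
/// `Arc<dyn LanguageModel>` obtained elsewhere.
///
/// ```ignore
/// let thread = cx.new_model(|cx| Thread::new(cx));
/// thread.update(cx, |thread, cx| {
///     thread.insert_user_message("Hello!");
///     let request = thread.to_completion_request(RequestKind::Chat, cx);
///     thread.stream_completion(request, model, cx);
/// });
/// ```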
pub struct Thread {
    messages: Vec<Message>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
}

impl Thread {
    pub fn new(_cx: &mut ModelContext<Self>) -> Self {
        Self {
            messages: Vec::new(),
            completion_count: 0,
            pending_completions: Vec::new(),
        }
    }

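    /// Returns an iterator over the thread's messages, oldest first.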
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }

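    /// Appends a new [`Role::User`] message containing `text`.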
    pub fn insert_user_message(&mut self, text: impl Into<String>) {
        self.messages.push(Message {
            role: Role::User,
            text: text.into(),
        });
    }

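    /// Builds a [`LanguageModelRequest`] from the messages in this thread.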
    pub fn to_completion_request(
        &self,
        _request_kind: RequestKind,
        _cx: &AppContext,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            request_message
                .content
                .push(MessageContent::Text(message.text.clone()));

            request.messages.push(request_message);
        }

        request
    }

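    /// Spawns a task that streams a completion from `model`, appending the
    /// streamed text to the thread and emitting
    /// [`ThreadEvent::StreamedCompletion`] as events arrive.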
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut ModelContext<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(|thread, mut cx| async move {
            let stream = model.stream_completion(request, &cx);

            // Run the streaming loop inside an inner async block so a single
            // `?`-based error path can be logged once at the end.
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(&mut cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                // The model began a new message; start an empty
                                // assistant message to accumulate text into.
                                thread.messages.push(Message {
                                    role: Role::Assistant,
                                    text: String::new(),
                                });
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                // Append the streamed chunk to the assistant
                                // message started above, if any.
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.text.push_str(&chunk);
                                    }
                                }
                            }
                            // Tool use is not handled yet.
                            LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
                        }

                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();
                    })?;

                    // Yield so other tasks get a chance to run between events.
                    smol::future::yield_now().await;
                }

                // The stream has ended, so this completion is no longer pending.
                thread.update(&mut cx, |thread, _cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;
            let _ = result.log_err();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
}

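/// An event emitted by a [`Thread`] while a completion is streaming.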
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Data from an in-flight completion was applied to the thread.
    StreamedCompletion,
}

impl EventEmitter<ThreadEvent> for Thread {}