use std::sync::Arc;

use anyhow::Result;
use assistant_tool::ToolWorkingSet;
use collections::HashMap;
use futures::future::Shared;
use futures::{FutureExt as _, StreamExt as _};
use gpui::{AppContext, EventEmitter, ModelContext, Task};
use language_model::{
    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
    LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
    StopReason,
};
use serde::{Deserialize, Serialize};
use util::post_inc;

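/// The kind of request being sent to the language model.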
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    Chat,
}

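/// A unique identifier for a [`Message`] within a [`Thread`].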
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(usize);

impl MessageId {
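    /// Returns the current ID and increments it for the next message.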
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}

/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
}

/// A thread of conversation with the LLM.
pub struct Thread {
    messages: Vec<Message>,
    next_message_id: MessageId,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    tools: Arc<ToolWorkingSet>,
    tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}

impl Thread {
    pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
        Self {
            messages: Vec::new(),
            next_message_id: MessageId(0),
            completion_count: 0,
            pending_completions: Vec::new(),
            tools,
            tool_uses_by_message: HashMap::default(),
            tool_results_by_message: HashMap::default(),
            pending_tool_uses_by_id: HashMap::default(),
        }
    }

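    /// Returns an iterator over the messages in this thread, in order.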
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }

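    /// Returns the working set of tools available to this thread.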
    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
        &self.tools
    }

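    /// Returns the tool uses that are still awaiting output.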
    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
        self.pending_tool_uses_by_id.values().collect()
    }

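    /// Appends a new user message to the end of the thread.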
    pub fn insert_user_message(&mut self, text: impl Into<String>) {
        self.messages.push(Message {
            id: self.next_message_id.post_inc(),
            role: Role::User,
            text: text.into(),
        });
    }

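    /// Builds a [`LanguageModelRequest`] from this thread's messages, attaching
    /// each message's recorded tool results and tool uses alongside its text.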
    pub fn to_completion_request(
        &self,
        _request_kind: RequestKind,
        _cx: &AppContext,
    ) -> LanguageModelRequest {
        let mut request = LanguageModelRequest {
            messages: vec![],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: None,
        };

        for message in &self.messages {
            let mut request_message = LanguageModelRequestMessage {
                role: message.role,
                content: Vec::new(),
                cache: false,
            };

            if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
                for tool_result in tool_results {
                    request_message
                        .content
                        .push(MessageContent::ToolResult(tool_result.clone()));
                }
            }

            if !message.text.is_empty() {
                request_message
                    .content
                    .push(MessageContent::Text(message.text.clone()));
            }

            if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
                for tool_use in tool_uses {
                    request_message
                        .content
                        .push(MessageContent::ToolUse(tool_use.clone()));
                }
            }

            request.messages.push(request_message);
        }

        request
    }

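    /// Streams a completion for `request` from the given model, appending
    /// assistant text and recording tool uses as events arrive.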
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut ModelContext<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(|thread, mut cx| async move {
            let stream = model.stream_completion(request, &cx);
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(&mut cx, |thread, cx| {
                        match event {
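                            // Begin a fresh assistant message to accumulate the streamed text.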
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.messages.push(Message {
                                    id: thread.next_message_id.post_inc(),
                                    role: Role::Assistant,
                                    text: String::new(),
                                });
                            }
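                            // Remember why the model stopped; we act on it after the stream ends.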
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
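                            // Append streamed text to the assistant message being built.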
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.text.push_str(&chunk);
                                    }
                                }
                            }
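                            // Record the tool use on the latest assistant
                            // message and mark it as pending.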
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                if let Some(last_assistant_message) = thread
                                    .messages
                                    .iter()
                                    .rfind(|message| message.role == Role::Assistant)
                                {
                                    thread
                                        .tool_uses_by_message
                                        .entry(last_assistant_message.id)
                                        .or_default()
                                        .push(tool_use.clone());

                                    thread.pending_tool_uses_by_id.insert(
                                        tool_use.id.clone(),
                                        PendingToolUse {
                                            assistant_message_id: last_assistant_message.id,
                                            id: tool_use.id,
                                            name: tool_use.name,
                                            input: tool_use.input,
                                            status: PendingToolUseStatus::Idle,
                                        },
                                    );
                                }
                            }
                        }

                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();
                    })?;

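                    // Yield so other tasks on the executor can run between streamed events.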
                    smol::future::yield_now().await;
                }

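                // The stream has ended; this completion is no longer pending.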
                thread.update(&mut cx, |thread, _cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(&mut cx, |_thread, cx| {
                    if let Err(error) = result.as_ref() {
                        let error_message = error
                            .chain()
                            .map(|err| err.to_string())
                            .collect::<Vec<_>>()
                            .join("\n");
                        eprintln!("Completion failed: {error_message}");
                    }

                    if let Ok(stop_reason) = result {
                        match stop_reason {
                            StopReason::ToolUse => {
                                cx.emit(ThreadEvent::UsePendingTools);
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        }
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }

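    /// Records the output of a pending tool use once its task resolves, so it
    /// can be sent back to the model with the next request.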
    pub fn insert_tool_output(
        &mut self,
        assistant_message_id: MessageId,
        tool_use_id: LanguageModelToolUseId,
        output: Task<Result<String>>,
        cx: &mut ModelContext<Self>,
    ) {
        let insert_output_task = cx.spawn(|thread, mut cx| {
            let tool_use_id = tool_use_id.clone();
            async move {
                let output = output.await;
                thread
                    .update(&mut cx, |thread, cx| {
                        // The tool use was requested by an Assistant message,
                        // so we want to attach the tool results to the next
                        // user message.
                        let next_user_message = MessageId(assistant_message_id.0 + 1);

                        let tool_results = thread
                            .tool_results_by_message
                            .entry(next_user_message)
                            .or_default();

                        match output {
                            Ok(output) => {
                                tool_results.push(LanguageModelToolResult {
                                    tool_use_id: tool_use_id.to_string(),
                                    content: output,
                                    is_error: false,
                                });

                                cx.emit(ThreadEvent::ToolFinished { tool_use_id });
                            }
                            Err(err) => {
                                tool_results.push(LanguageModelToolResult {
                                    tool_use_id: tool_use_id.to_string(),
                                    content: err.to_string(),
                                    is_error: true,
                                });

                                if let Some(tool_use) =
                                    thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
                                {
                                    tool_use.status = PendingToolUseStatus::Error(err.to_string());
                                }
                            }
                        }
                    })
                    .ok();
            }
        });

        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
            tool_use.status = PendingToolUseStatus::Running {
                _task: insert_output_task.shared(),
            };
        }
    }
}

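/// An event emitted by a [`Thread`].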
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    StreamedCompletion,
    UsePendingTools,
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
    },
}

impl EventEmitter<ThreadEvent> for Thread {}

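/// A completion request that is currently streaming from the model.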
struct PendingCompletion {
    id: usize,
    _task: Task<()>,
}

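/// A tool use requested by the model that has not yet produced a result.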
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    pub assistant_message_id: MessageId,
    pub name: String,
    pub input: serde_json::Value,
    pub status: PendingToolUseStatus,
}

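/// The lifecycle state of a [`PendingToolUse`].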
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    Idle,
    Running { _task: Shared<Task<()>> },
    Error(#[allow(unused)] String),
}

impl PendingToolUseStatus {
    pub fn is_idle(&self) -> bool {
        matches!(self, PendingToolUseStatus::Idle)
    }
}