thread.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use collections::HashMap;
  6use futures::future::Shared;
  7use futures::{FutureExt as _, StreamExt as _};
  8use gpui::{AppContext, EventEmitter, ModelContext, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
 11    LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 12    StopReason,
 13};
 14use serde::{Deserialize, Serialize};
 15use util::post_inc;
 16
 17#[derive(Debug, Clone, Copy)]
 18pub enum RequestKind {
 19    Chat,
 20}
 21
 22#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 23pub struct MessageId(usize);
 24
 25impl MessageId {
 26    fn post_inc(&mut self) -> Self {
 27        Self(post_inc(&mut self.0))
 28    }
 29}
 30
 31/// A message in a [`Thread`].
 32#[derive(Debug, Clone)]
 33pub struct Message {
 34    pub id: MessageId,
 35    pub role: Role,
 36    pub text: String,
 37}
 38
 39/// A thread of conversation with the LLM.
 40pub struct Thread {
 41    messages: Vec<Message>,
 42    next_message_id: MessageId,
 43    completion_count: usize,
 44    pending_completions: Vec<PendingCompletion>,
 45    tools: Arc<ToolWorkingSet>,
 46    tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 47    tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
 48    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 49}
 50
 51impl Thread {
 52    pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
 53        Self {
 54            messages: Vec::new(),
 55            next_message_id: MessageId(0),
 56            completion_count: 0,
 57            pending_completions: Vec::new(),
 58            tools,
 59            tool_uses_by_message: HashMap::default(),
 60            tool_results_by_message: HashMap::default(),
 61            pending_tool_uses_by_id: HashMap::default(),
 62        }
 63    }
 64
 65    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 66        self.messages.iter()
 67    }
 68
 69    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 70        &self.tools
 71    }
 72
 73    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
 74        self.pending_tool_uses_by_id.values().collect()
 75    }
 76
 77    pub fn insert_user_message(&mut self, text: impl Into<String>) {
 78        self.messages.push(Message {
 79            id: self.next_message_id.post_inc(),
 80            role: Role::User,
 81            text: text.into(),
 82        });
 83    }
 84
 85    pub fn to_completion_request(
 86        &self,
 87        _request_kind: RequestKind,
 88        _cx: &AppContext,
 89    ) -> LanguageModelRequest {
 90        let mut request = LanguageModelRequest {
 91            messages: vec![],
 92            tools: Vec::new(),
 93            stop: Vec::new(),
 94            temperature: None,
 95        };
 96
 97        for message in &self.messages {
 98            let mut request_message = LanguageModelRequestMessage {
 99                role: message.role,
100                content: Vec::new(),
101                cache: false,
102            };
103
104            if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
105                for tool_result in tool_results {
106                    request_message
107                        .content
108                        .push(MessageContent::ToolResult(tool_result.clone()));
109                }
110            }
111
112            if !message.text.is_empty() {
113                request_message
114                    .content
115                    .push(MessageContent::Text(message.text.clone()));
116            }
117
118            if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
119                for tool_use in tool_uses {
120                    request_message
121                        .content
122                        .push(MessageContent::ToolUse(tool_use.clone()));
123                }
124            }
125
126            request.messages.push(request_message);
127        }
128
129        request
130    }
131
132    pub fn stream_completion(
133        &mut self,
134        request: LanguageModelRequest,
135        model: Arc<dyn LanguageModel>,
136        cx: &mut ModelContext<Self>,
137    ) {
138        let pending_completion_id = post_inc(&mut self.completion_count);
139
140        let task = cx.spawn(|thread, mut cx| async move {
141            let stream = model.stream_completion(request, &cx);
142            let stream_completion = async {
143                let mut events = stream.await?;
144                let mut stop_reason = StopReason::EndTurn;
145
146                while let Some(event) = events.next().await {
147                    let event = event?;
148
149                    thread.update(&mut cx, |thread, cx| {
150                        match event {
151                            LanguageModelCompletionEvent::StartMessage { .. } => {
152                                thread.messages.push(Message {
153                                    id: thread.next_message_id.post_inc(),
154                                    role: Role::Assistant,
155                                    text: String::new(),
156                                });
157                            }
158                            LanguageModelCompletionEvent::Stop(reason) => {
159                                stop_reason = reason;
160                            }
161                            LanguageModelCompletionEvent::Text(chunk) => {
162                                if let Some(last_message) = thread.messages.last_mut() {
163                                    if last_message.role == Role::Assistant {
164                                        last_message.text.push_str(&chunk);
165                                    }
166                                }
167                            }
168                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
169                                if let Some(last_assistant_message) = thread
170                                    .messages
171                                    .iter()
172                                    .rfind(|message| message.role == Role::Assistant)
173                                {
174                                    thread
175                                        .tool_uses_by_message
176                                        .entry(last_assistant_message.id)
177                                        .or_default()
178                                        .push(tool_use.clone());
179
180                                    thread.pending_tool_uses_by_id.insert(
181                                        tool_use.id.clone(),
182                                        PendingToolUse {
183                                            assistant_message_id: last_assistant_message.id,
184                                            id: tool_use.id,
185                                            name: tool_use.name,
186                                            input: tool_use.input,
187                                            status: PendingToolUseStatus::Idle,
188                                        },
189                                    );
190                                }
191                            }
192                        }
193
194                        cx.emit(ThreadEvent::StreamedCompletion);
195                        cx.notify();
196                    })?;
197
198                    smol::future::yield_now().await;
199                }
200
201                thread.update(&mut cx, |thread, _cx| {
202                    thread
203                        .pending_completions
204                        .retain(|completion| completion.id != pending_completion_id);
205                })?;
206
207                anyhow::Ok(stop_reason)
208            };
209
210            let result = stream_completion.await;
211
212            thread
213                .update(&mut cx, |_thread, cx| {
214                    let error_message = if let Some(error) = result.as_ref().err() {
215                        let error_message = error
216                            .chain()
217                            .map(|err| err.to_string())
218                            .collect::<Vec<_>>()
219                            .join("\n");
220                        Some(error_message)
221                    } else {
222                        None
223                    };
224
225                    if let Some(error_message) = error_message {
226                        eprintln!("Completion failed: {error_message:?}");
227                    }
228
229                    if let Ok(stop_reason) = result {
230                        match stop_reason {
231                            StopReason::ToolUse => {
232                                cx.emit(ThreadEvent::UsePendingTools);
233                            }
234                            StopReason::EndTurn => {}
235                            StopReason::MaxTokens => {}
236                        }
237                    }
238                })
239                .ok();
240        });
241
242        self.pending_completions.push(PendingCompletion {
243            id: pending_completion_id,
244            _task: task,
245        });
246    }
247
248    pub fn insert_tool_output(
249        &mut self,
250        assistant_message_id: MessageId,
251        tool_use_id: LanguageModelToolUseId,
252        output: Task<Result<String>>,
253        cx: &mut ModelContext<Self>,
254    ) {
255        let insert_output_task = cx.spawn(|thread, mut cx| {
256            let tool_use_id = tool_use_id.clone();
257            async move {
258                let output = output.await;
259                thread
260                    .update(&mut cx, |thread, cx| {
261                        // The tool use was requested by an Assistant message,
262                        // so we want to attach the tool results to the next
263                        // user message.
264                        let next_user_message = MessageId(assistant_message_id.0 + 1);
265
266                        let tool_results = thread
267                            .tool_results_by_message
268                            .entry(next_user_message)
269                            .or_default();
270
271                        match output {
272                            Ok(output) => {
273                                tool_results.push(LanguageModelToolResult {
274                                    tool_use_id: tool_use_id.to_string(),
275                                    content: output,
276                                    is_error: false,
277                                });
278
279                                cx.emit(ThreadEvent::ToolFinished { tool_use_id });
280                            }
281                            Err(err) => {
282                                tool_results.push(LanguageModelToolResult {
283                                    tool_use_id: tool_use_id.to_string(),
284                                    content: err.to_string(),
285                                    is_error: true,
286                                });
287
288                                if let Some(tool_use) =
289                                    thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
290                                {
291                                    tool_use.status = PendingToolUseStatus::Error(err.to_string());
292                                }
293                            }
294                        }
295                    })
296                    .ok();
297            }
298        });
299
300        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
301            tool_use.status = PendingToolUseStatus::Running {
302                _task: insert_output_task.shared(),
303            };
304        }
305    }
306}
307
308#[derive(Debug, Clone)]
309pub enum ThreadEvent {
310    StreamedCompletion,
311    UsePendingTools,
312    ToolFinished {
313        #[allow(unused)]
314        tool_use_id: LanguageModelToolUseId,
315    },
316}
317
318impl EventEmitter<ThreadEvent> for Thread {}
319
320struct PendingCompletion {
321    id: usize,
322    _task: Task<()>,
323}
324
325#[derive(Debug, Clone)]
326pub struct PendingToolUse {
327    pub id: LanguageModelToolUseId,
328    /// The ID of the Assistant message in which the tool use was requested.
329    pub assistant_message_id: MessageId,
330    pub name: String,
331    pub input: serde_json::Value,
332    pub status: PendingToolUseStatus,
333}
334
335#[derive(Debug, Clone)]
336pub enum PendingToolUseStatus {
337    Idle,
338    Running { _task: Shared<Task<()>> },
339    Error(#[allow(unused)] String),
340}
341
342impl PendingToolUseStatus {
343    pub fn is_idle(&self) -> bool {
344        matches!(self, PendingToolUseStatus::Idle)
345    }
346}