thread.rs

  1use crate::templates::Templates;
  2use anyhow::{anyhow, Result};
  3use cloud_llm_client::{CompletionIntent, CompletionMode};
  4use futures::{channel::mpsc, future};
  5use gpui::{App, Context, SharedString, Task};
  6use language_model::{
  7    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  8    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  9    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
 10    LanguageModelToolUse, MessageContent, Role, StopReason,
 11};
 12use schemars::{JsonSchema, Schema};
 13use serde::Deserialize;
 14use smol::stream::StreamExt;
 15use std::{collections::BTreeMap, sync::Arc};
 16use util::ResultExt;
 17
 18#[derive(Debug)]
 19pub struct AgentMessage {
 20    pub role: Role,
 21    pub content: Vec<MessageContent>,
 22}
 23
 24pub type AgentResponseEvent = LanguageModelCompletionEvent;
 25
 26pub trait Prompt {
 27    fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
 28}
 29
 30pub struct Thread {
 31    messages: Vec<AgentMessage>,
 32    completion_mode: CompletionMode,
 33    /// Holds the task that handles agent interaction until the end of the turn.
 34    /// Survives across multiple requests as the model performs tool calls and
 35    /// we run tools, report their results.
 36    running_turn: Option<Task<()>>,
 37    system_prompts: Vec<Arc<dyn Prompt>>,
 38    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 39    templates: Arc<Templates>,
 40    // project: Entity<Project>,
 41    // action_log: Entity<ActionLog>,
 42}
 43
 44impl Thread {
 45    pub fn new(templates: Arc<Templates>) -> Self {
 46        Self {
 47            messages: Vec::new(),
 48            completion_mode: CompletionMode::Normal,
 49            system_prompts: Vec::new(),
 50            running_turn: None,
 51            tools: BTreeMap::default(),
 52            templates,
 53        }
 54    }
 55
 56    pub fn set_mode(&mut self, mode: CompletionMode) {
 57        self.completion_mode = mode;
 58    }
 59
 60    pub fn messages(&self) -> &[AgentMessage] {
 61        &self.messages
 62    }
 63
 64    pub fn add_tool(&mut self, tool: impl AgentTool) {
 65        self.tools.insert(tool.name(), tool.erase());
 66    }
 67
 68    pub fn remove_tool(&mut self, name: &str) -> bool {
 69        self.tools.remove(name).is_some()
 70    }
 71
 72    /// Sending a message results in the model streaming a response, which could include tool calls.
 73    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 74    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 75    pub fn send(
 76        &mut self,
 77        model: Arc<dyn LanguageModel>,
 78        content: impl Into<MessageContent>,
 79        cx: &mut Context<Self>,
 80    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
 81        cx.notify();
 82        let (events_tx, events_rx) =
 83            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 84
 85        let system_message = self.build_system_message(cx);
 86        self.messages.extend(system_message);
 87
 88        self.messages.push(AgentMessage {
 89            role: Role::User,
 90            content: vec![content.into()],
 91        });
 92        self.running_turn = Some(cx.spawn(async move |thread, cx| {
 93            let turn_result = async {
 94                // Perform one request, then keep looping if the model makes tool calls.
 95                let mut completion_intent = CompletionIntent::UserPrompt;
 96                loop {
 97                    let request = thread.update(cx, |thread, cx| {
 98                        thread.build_completion_request(completion_intent, cx)
 99                    })?;
100
101                    // println!(
102                    //     "request: {}",
103                    //     serde_json::to_string_pretty(&request).unwrap()
104                    // );
105
106                    // Stream events, appending to messages and collecting up tool uses.
107                    let mut events = model.stream_completion(request, cx).await?;
108                    let mut tool_uses = Vec::new();
109                    while let Some(event) = events.next().await {
110                        match event {
111                            Ok(event) => {
112                                thread
113                                    .update(cx, |thread, cx| {
114                                        tool_uses.extend(thread.handle_streamed_completion_event(
115                                            event,
116                                            events_tx.clone(),
117                                            cx,
118                                        ));
119                                    })
120                                    .ok();
121                            }
122                            Err(error) => {
123                                events_tx.unbounded_send(Err(error)).ok();
124                                break;
125                            }
126                        }
127                    }
128
129                    // If there are no tool uses, the turn is done.
130                    if tool_uses.is_empty() {
131                        break;
132                    }
133
134                    // If there are tool uses, wait for their results to be
135                    // computed, then send them together in a single message on
136                    // the next loop iteration.
137                    let tool_results = future::join_all(tool_uses).await;
138                    thread
139                        .update(cx, |thread, _cx| {
140                            thread.messages.push(AgentMessage {
141                                role: Role::User,
142                                content: tool_results
143                                    .into_iter()
144                                    .map(MessageContent::ToolResult)
145                                    .collect(),
146                            });
147                        })
148                        .ok();
149                    completion_intent = CompletionIntent::ToolResults;
150                }
151
152                Ok(())
153            }
154            .await;
155
156            if let Err(error) = turn_result {
157                events_tx.unbounded_send(Err(error)).ok();
158            }
159        }));
160        events_rx
161    }
162
163    pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
164        let mut system_message = AgentMessage {
165            role: Role::System,
166            content: Vec::new(),
167        };
168
169        for prompt in &self.system_prompts {
170            if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
171                system_message
172                    .content
173                    .push(MessageContent::Text(rendered_prompt));
174            }
175        }
176
177        (!system_message.content.is_empty()).then_some(system_message)
178    }
179
180    /// A helper method that's called on every streamed completion event.
181    /// Returns an optional tool result task, which the main agentic loop in
182    /// send will send back to the model when it resolves.
183    fn handle_streamed_completion_event(
184        &mut self,
185        event: LanguageModelCompletionEvent,
186        events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
187        cx: &mut Context<Self>,
188    ) -> Option<Task<LanguageModelToolResult>> {
189        use LanguageModelCompletionEvent::*;
190        events_tx.unbounded_send(Ok(event.clone())).ok();
191
192        match event {
193            Text(new_text) => self.handle_text_event(new_text, cx),
194            Thinking {
195                text: _text,
196                signature: _signature,
197            } => {
198                todo!()
199            }
200            ToolUse(tool_use) => {
201                return self.handle_tool_use_event(tool_use, cx);
202            }
203            StartMessage { .. } => {
204                self.messages.push(AgentMessage {
205                    role: Role::Assistant,
206                    content: Vec::new(),
207                });
208            }
209            UsageUpdate(_) => {}
210            Stop(stop_reason) => self.handle_stop_event(stop_reason),
211            StatusUpdate(_completion_request_status) => {}
212            RedactedThinking { data: _data } => todo!(),
213            ToolUseJsonParseError {
214                id: _id,
215                tool_name: _tool_name,
216                raw_input: _raw_input,
217                json_parse_error: _json_parse_error,
218            } => todo!(),
219        }
220
221        None
222    }
223
224    fn handle_stop_event(&mut self, stop_reason: StopReason) {
225        match stop_reason {
226            StopReason::EndTurn | StopReason::ToolUse => {}
227            StopReason::MaxTokens => todo!(),
228            StopReason::Refusal => todo!(),
229        }
230    }
231
232    fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
233        let last_message = self.last_assistant_message();
234        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
235            text.push_str(&new_text);
236        } else {
237            last_message.content.push(MessageContent::Text(new_text));
238        }
239
240        cx.notify();
241    }
242
243    fn handle_tool_use_event(
244        &mut self,
245        tool_use: LanguageModelToolUse,
246        cx: &mut Context<Self>,
247    ) -> Option<Task<LanguageModelToolResult>> {
248        cx.notify();
249
250        let last_message = self.last_assistant_message();
251
252        // Ensure the last message ends in the current tool use
253        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
254            if let MessageContent::ToolUse(last_tool_use) = content {
255                if last_tool_use.id == tool_use.id {
256                    *last_tool_use = tool_use.clone();
257                    false
258                } else {
259                    true
260                }
261            } else {
262                true
263            }
264        });
265        if push_new_tool_use {
266            last_message
267                .content
268                .push(MessageContent::ToolUse(tool_use.clone()));
269        }
270
271        if !tool_use.is_input_complete {
272            return None;
273        }
274
275        if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
276            let pending_tool_result = tool.clone().run(tool_use.input, cx);
277
278            Some(cx.foreground_executor().spawn(async move {
279                match pending_tool_result.await {
280                    Ok(tool_output) => LanguageModelToolResult {
281                        tool_use_id: tool_use.id,
282                        tool_name: tool_use.name,
283                        is_error: false,
284                        content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
285                        output: None,
286                    },
287                    Err(error) => LanguageModelToolResult {
288                        tool_use_id: tool_use.id,
289                        tool_name: tool_use.name,
290                        is_error: true,
291                        content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
292                        output: None,
293                    },
294                }
295            }))
296        } else {
297            Some(Task::ready(LanguageModelToolResult {
298                content: LanguageModelToolResultContent::Text(Arc::from(format!(
299                    "No tool named {} exists",
300                    tool_use.name
301                ))),
302                tool_use_id: tool_use.id,
303                tool_name: tool_use.name,
304                is_error: true,
305                output: None,
306            }))
307        }
308    }
309
310    /// Guarantees the last message is from the assistant and returns a mutable reference.
311    fn last_assistant_message(&mut self) -> &mut AgentMessage {
312        if self
313            .messages
314            .last()
315            .map_or(true, |m| m.role != Role::Assistant)
316        {
317            self.messages.push(AgentMessage {
318                role: Role::Assistant,
319                content: Vec::new(),
320            });
321        }
322        self.messages.last_mut().unwrap()
323    }
324
325    fn build_completion_request(
326        &self,
327        completion_intent: CompletionIntent,
328        cx: &mut App,
329    ) -> LanguageModelRequest {
330        LanguageModelRequest {
331            thread_id: None,
332            prompt_id: None,
333            intent: Some(completion_intent),
334            mode: Some(self.completion_mode),
335            messages: self.build_request_messages(),
336            tools: self
337                .tools
338                .values()
339                .filter_map(|tool| {
340                    Some(LanguageModelRequestTool {
341                        name: tool.name().to_string(),
342                        description: tool.description(cx).to_string(),
343                        input_schema: tool
344                            .input_schema(LanguageModelToolSchemaFormat::JsonSchema)
345                            .log_err()?,
346                    })
347                })
348                .collect(),
349            tool_choice: None,
350            stop: Vec::new(),
351            temperature: None,
352            thinking_allowed: false,
353        }
354    }
355
356    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
357        self.messages
358            .iter()
359            .map(|message| LanguageModelRequestMessage {
360                role: message.role,
361                content: message.content.clone(),
362                cache: false,
363            })
364            .collect()
365    }
366}
367
368pub trait AgentTool
369where
370    Self: 'static + Sized,
371{
372    type Input: for<'de> Deserialize<'de> + JsonSchema;
373
374    fn name(&self) -> SharedString;
375    fn description(&self, _cx: &mut App) -> SharedString {
376        let schema = schemars::schema_for!(Self::Input);
377        SharedString::new(
378            schema
379                .get("description")
380                .and_then(|description| description.as_str())
381                .unwrap_or_default(),
382        )
383    }
384
385    /// Returns the JSON schema that describes the tool's input.
386    fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema {
387        schemars::schema_for!(Self::Input)
388    }
389
390    /// Runs the tool with the provided input.
391    fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
392
393    fn erase(self) -> Arc<dyn AnyAgentTool> {
394        Arc::new(Erased(Arc::new(self)))
395    }
396}
397
/// Wrapper used to implement [`AnyAgentTool`] for a concrete [`AgentTool`].
pub struct Erased<T>(T);
399
400pub trait AnyAgentTool {
401    fn name(&self) -> SharedString;
402    fn description(&self, cx: &mut App) -> SharedString;
403    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
404    fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
405}
406
407impl<T> AnyAgentTool for Erased<Arc<T>>
408where
409    T: AgentTool,
410{
411    fn name(&self) -> SharedString {
412        self.0.name()
413    }
414
415    fn description(&self, cx: &mut App) -> SharedString {
416        self.0.description(cx)
417    }
418
419    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
420        Ok(serde_json::to_value(self.0.input_schema(format))?)
421    }
422
423    fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
424        let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
425        match parsed_input {
426            Ok(input) => self.0.clone().run(input, cx),
427            Err(error) => Task::ready(Err(anyhow!(error))),
428        }
429    }
430}