thread.rs

  1use crate::templates::Templates;
  2use anyhow::{anyhow, Result};
  3use cloud_llm_client::{CompletionIntent, CompletionMode};
  4use futures::{channel::mpsc, future};
  5use gpui::{App, Context, SharedString, Task};
  6use language_model::{
  7    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  8    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  9    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
 10    LanguageModelToolUse, MessageContent, Role, StopReason,
 11};
 12use log;
 13use schemars::{JsonSchema, Schema};
 14use serde::Deserialize;
 15use smol::stream::StreamExt;
 16use std::{collections::BTreeMap, sync::Arc};
 17use util::ResultExt;
 18
/// A single message in the agent conversation.
#[derive(Debug)]
pub struct AgentMessage {
    /// Author of the message: system, user, or assistant.
    pub role: Role,
    /// Ordered content items of the message. In this file these are text chunks,
    /// tool uses requested by the model, and tool results sent back to it.
    pub content: Vec<MessageContent>,
}
 24
/// Event type delivered on the channel returned by `Thread::send`.
/// Currently an alias for the language model's own completion event.
pub type AgentResponseEvent = LanguageModelCompletionEvent;
 26
/// A piece of the system prompt that can be rendered to text.
pub trait Prompt {
    /// Renders the prompt to a string using the given templates and app context.
    ///
    /// # Errors
    ///
    /// Returns an error when rendering fails; `Thread::build_system_message` logs
    /// such errors and skips the prompt.
    fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
}
 30
/// Conversation state for an agent: the message history, the registered tools, and
/// the task driving the turn that is currently in progress.
pub struct Thread {
    /// Conversation history in order; every completion request is built from it.
    messages: Vec<AgentMessage>,
    /// Completion mode attached to every request built by this thread.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run the tools and report their results.
    running_turn: Option<Task<()>>,
    /// Prompts rendered into a system message at the start of each `send`.
    system_prompts: Vec<Arc<dyn Prompt>>,
    /// Tools offered to the model, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Templates handed to each system prompt when it is rendered.
    templates: Arc<Templates>,
    /// Model the thread was created with.
    /// NOTE(review): `send` uses the model passed as its argument; nothing in this
    /// file reads this field.
    pub selected_model: Arc<dyn LanguageModel>,
    // action_log: Entity<ActionLog>,
}
 44
 45impl Thread {
 46    pub fn new(templates: Arc<Templates>, default_model: Arc<dyn LanguageModel>) -> Self {
 47        Self {
 48            messages: Vec::new(),
 49            completion_mode: CompletionMode::Normal,
 50            system_prompts: Vec::new(),
 51            running_turn: None,
 52            tools: BTreeMap::default(),
 53            templates,
 54            selected_model: default_model,
 55        }
 56    }
 57
    /// Sets the completion mode attached to subsequently built completion requests.
    pub fn set_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 61
 62    pub fn messages(&self) -> &[AgentMessage] {
 63        &self.messages
 64    }
 65
 66    pub fn add_tool(&mut self, tool: impl AgentTool) {
 67        self.tools.insert(tool.name(), tool.erase());
 68    }
 69
 70    pub fn remove_tool(&mut self, name: &str) -> bool {
 71        self.tools.remove(name).is_some()
 72    }
 73
 74    /// Sending a message results in the model streaming a response, which could include tool calls.
 75    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 76    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 77    pub fn send(
 78        &mut self,
 79        model: Arc<dyn LanguageModel>,
 80        content: impl Into<MessageContent>,
 81        cx: &mut Context<Self>,
 82    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
 83        let content = content.into();
 84        log::info!("Thread::send called with model: {:?}", model.name());
 85        log::debug!("Thread::send content: {:?}", content);
 86
 87        cx.notify();
 88        let (events_tx, events_rx) =
 89            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 90
 91        let system_message = self.build_system_message(cx);
 92        log::debug!(
 93            "System messages count: {}",
 94            if system_message.is_some() { 1 } else { 0 }
 95        );
 96        self.messages.extend(system_message);
 97
 98        self.messages.push(AgentMessage {
 99            role: Role::User,
100            content: vec![content],
101        });
102        log::info!("Total messages in thread: {}", self.messages.len());
103        self.running_turn = Some(cx.spawn(async move |thread, cx| {
104            log::info!("Starting agent turn execution");
105            let turn_result = async {
106                // Perform one request, then keep looping if the model makes tool calls.
107                let mut completion_intent = CompletionIntent::UserPrompt;
108                loop {
109                    log::debug!(
110                        "Building completion request with intent: {:?}",
111                        completion_intent
112                    );
113                    let request = thread.update(cx, |thread, cx| {
114                        thread.build_completion_request(completion_intent, cx)
115                    })?;
116
117                    // println!(
118                    //     "request: {}",
119                    //     serde_json::to_string_pretty(&request).unwrap()
120                    // );
121
122                    // Stream events, appending to messages and collecting up tool uses.
123                    log::info!("Calling model.stream_completion");
124                    let mut events = model.stream_completion(request, cx).await?;
125                    log::debug!("Stream completion started successfully");
126                    let mut tool_uses = Vec::new();
127                    while let Some(event) = events.next().await {
128                        match event {
129                            Ok(event) => {
130                                log::trace!("Received completion event: {:?}", event);
131                                thread
132                                    .update(cx, |thread, cx| {
133                                        tool_uses.extend(thread.handle_streamed_completion_event(
134                                            event,
135                                            events_tx.clone(),
136                                            cx,
137                                        ));
138                                    })
139                                    .ok();
140                            }
141                            Err(error) => {
142                                log::error!("Error in completion stream: {:?}", error);
143                                events_tx.unbounded_send(Err(error)).ok();
144                                break;
145                            }
146                        }
147                    }
148
149                    // If there are no tool uses, the turn is done.
150                    if tool_uses.is_empty() {
151                        log::info!("No tool uses found, completing turn");
152                        break;
153                    }
154                    log::info!("Found {} tool uses to execute", tool_uses.len());
155
156                    // If there are tool uses, wait for their results to be
157                    // computed, then send them together in a single message on
158                    // the next loop iteration.
159                    let tool_results = future::join_all(tool_uses).await;
160                    log::debug!("Tool execution completed, {} results", tool_results.len());
161                    thread
162                        .update(cx, |thread, _cx| {
163                            thread.messages.push(AgentMessage {
164                                role: Role::User,
165                                content: tool_results
166                                    .into_iter()
167                                    .map(MessageContent::ToolResult)
168                                    .collect(),
169                            });
170                        })
171                        .ok();
172                    completion_intent = CompletionIntent::ToolResults;
173                }
174
175                Ok(())
176            }
177            .await;
178
179            if let Err(error) = turn_result {
180                log::error!("Turn execution failed: {:?}", error);
181                events_tx.unbounded_send(Err(error)).ok();
182            } else {
183                log::info!("Turn execution completed successfully");
184            }
185        }));
186        events_rx
187    }
188
189    pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
190        log::debug!("Building system message");
191        let mut system_message = AgentMessage {
192            role: Role::System,
193            content: Vec::new(),
194        };
195
196        for prompt in &self.system_prompts {
197            if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
198                system_message
199                    .content
200                    .push(MessageContent::Text(rendered_prompt));
201            }
202        }
203
204        let result = (!system_message.content.is_empty()).then_some(system_message);
205        log::debug!("System message built: {}", result.is_some());
206        result
207    }
208
209    /// A helper method that's called on every streamed completion event.
210    /// Returns an optional tool result task, which the main agentic loop in
211    /// send will send back to the model when it resolves.
212    fn handle_streamed_completion_event(
213        &mut self,
214        event: LanguageModelCompletionEvent,
215        events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
216        cx: &mut Context<Self>,
217    ) -> Option<Task<LanguageModelToolResult>> {
218        log::trace!("Handling streamed completion event: {:?}", event);
219        use LanguageModelCompletionEvent::*;
220        events_tx.unbounded_send(Ok(event.clone())).ok();
221
222        match event {
223            Text(new_text) => self.handle_text_event(new_text, cx),
224            Thinking {
225                text: _text,
226                signature: _signature,
227            } => {
228                todo!()
229            }
230            ToolUse(tool_use) => {
231                return self.handle_tool_use_event(tool_use, cx);
232            }
233            StartMessage { .. } => {
234                self.messages.push(AgentMessage {
235                    role: Role::Assistant,
236                    content: Vec::new(),
237                });
238            }
239            UsageUpdate(_) => {}
240            Stop(stop_reason) => self.handle_stop_event(stop_reason),
241            StatusUpdate(_completion_request_status) => {}
242            RedactedThinking { data: _data } => todo!(),
243            ToolUseJsonParseError {
244                id: _id,
245                tool_name: _tool_name,
246                raw_input: _raw_input,
247                json_parse_error: _json_parse_error,
248            } => todo!(),
249        }
250
251        None
252    }
253
254    fn handle_stop_event(&mut self, stop_reason: StopReason) {
255        match stop_reason {
256            StopReason::EndTurn | StopReason::ToolUse => {}
257            StopReason::MaxTokens => todo!(),
258            StopReason::Refusal => todo!(),
259        }
260    }
261
262    fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
263        let last_message = self.last_assistant_message();
264        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
265            text.push_str(&new_text);
266        } else {
267            last_message.content.push(MessageContent::Text(new_text));
268        }
269
270        cx.notify();
271    }
272
273    fn handle_tool_use_event(
274        &mut self,
275        tool_use: LanguageModelToolUse,
276        cx: &mut Context<Self>,
277    ) -> Option<Task<LanguageModelToolResult>> {
278        cx.notify();
279
280        let last_message = self.last_assistant_message();
281
282        // Ensure the last message ends in the current tool use
283        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
284            if let MessageContent::ToolUse(last_tool_use) = content {
285                if last_tool_use.id == tool_use.id {
286                    *last_tool_use = tool_use.clone();
287                    false
288                } else {
289                    true
290                }
291            } else {
292                true
293            }
294        });
295        if push_new_tool_use {
296            last_message
297                .content
298                .push(MessageContent::ToolUse(tool_use.clone()));
299        }
300
301        if !tool_use.is_input_complete {
302            return None;
303        }
304
305        if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
306            let pending_tool_result = tool.clone().run(tool_use.input, cx);
307
308            Some(cx.foreground_executor().spawn(async move {
309                match pending_tool_result.await {
310                    Ok(tool_output) => LanguageModelToolResult {
311                        tool_use_id: tool_use.id,
312                        tool_name: tool_use.name,
313                        is_error: false,
314                        content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
315                        output: None,
316                    },
317                    Err(error) => LanguageModelToolResult {
318                        tool_use_id: tool_use.id,
319                        tool_name: tool_use.name,
320                        is_error: true,
321                        content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
322                        output: None,
323                    },
324                }
325            }))
326        } else {
327            Some(Task::ready(LanguageModelToolResult {
328                content: LanguageModelToolResultContent::Text(Arc::from(format!(
329                    "No tool named {} exists",
330                    tool_use.name
331                ))),
332                tool_use_id: tool_use.id,
333                tool_name: tool_use.name,
334                is_error: true,
335                output: None,
336            }))
337        }
338    }
339
340    /// Guarantees the last message is from the assistant and returns a mutable reference.
341    fn last_assistant_message(&mut self) -> &mut AgentMessage {
342        if self
343            .messages
344            .last()
345            .map_or(true, |m| m.role != Role::Assistant)
346        {
347            self.messages.push(AgentMessage {
348                role: Role::Assistant,
349                content: Vec::new(),
350            });
351        }
352        self.messages.last_mut().unwrap()
353    }
354
355    fn build_completion_request(
356        &self,
357        completion_intent: CompletionIntent,
358        cx: &mut App,
359    ) -> LanguageModelRequest {
360        log::debug!("Building completion request");
361        log::debug!("Completion intent: {:?}", completion_intent);
362        log::debug!("Completion mode: {:?}", self.completion_mode);
363
364        let messages = self.build_request_messages();
365        log::info!("Request will include {} messages", messages.len());
366
367        let tools: Vec<LanguageModelRequestTool> = self
368            .tools
369            .values()
370            .filter_map(|tool| {
371                let tool_name = tool.name().to_string();
372                log::trace!("Including tool: {}", tool_name);
373                Some(LanguageModelRequestTool {
374                    name: tool_name,
375                    description: tool.description(cx).to_string(),
376                    input_schema: tool
377                        .input_schema(LanguageModelToolSchemaFormat::JsonSchema)
378                        .log_err()?,
379                })
380            })
381            .collect();
382
383        log::info!("Request includes {} tools", tools.len());
384
385        let request = LanguageModelRequest {
386            thread_id: None,
387            prompt_id: None,
388            intent: Some(completion_intent),
389            mode: Some(self.completion_mode),
390            messages,
391            tools,
392            tool_choice: None,
393            stop: Vec::new(),
394            temperature: None,
395            thinking_allowed: false,
396        };
397
398        log::debug!("Completion request built successfully");
399        request
400    }
401
402    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
403        log::trace!(
404            "Building request messages from {} thread messages",
405            self.messages.len()
406        );
407        let messages = self
408            .messages
409            .iter()
410            .map(|message| {
411                log::trace!(
412                    "  - {} message with {} content items",
413                    match message.role {
414                        Role::System => "System",
415                        Role::User => "User",
416                        Role::Assistant => "Assistant",
417                    },
418                    message.content.len()
419                );
420                LanguageModelRequestMessage {
421                    role: message.role,
422                    content: message.content.clone(),
423                    cache: false,
424                }
425            })
426            .collect();
427        messages
428    }
429}
430
431pub trait AgentTool
432where
433    Self: 'static + Sized,
434{
435    type Input: for<'de> Deserialize<'de> + JsonSchema;
436
437    fn name(&self) -> SharedString;
438    fn description(&self, _cx: &mut App) -> SharedString {
439        let schema = schemars::schema_for!(Self::Input);
440        SharedString::new(
441            schema
442                .get("description")
443                .and_then(|description| description.as_str())
444                .unwrap_or_default(),
445        )
446    }
447
448    /// Returns the JSON schema that describes the tool's input.
449    fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema {
450        schemars::schema_for!(Self::Input)
451    }
452
453    /// Runs the tool with the provided input.
454    fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
455
456    fn erase(self) -> Arc<dyn AnyAgentTool> {
457        Arc::new(Erased(Arc::new(self)))
458    }
459}
460
/// Wrapper produced by `AgentTool::erase`. `Erased<Arc<T>>` implements
/// `AnyAgentTool` for any `T: AgentTool`.
pub struct Erased<T>(T);
462
/// Type-erased counterpart of `AgentTool`, working with JSON values instead of a
/// typed input so tools with different input types can be stored together.
pub trait AnyAgentTool {
    /// Name under which the tool is registered and offered to the model.
    fn name(&self) -> SharedString;
    /// Description of the tool included in completion requests.
    fn description(&self, cx: &mut App) -> SharedString;
    /// JSON schema of the tool's input, or an error if it cannot be produced.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Runs the tool with the raw JSON input taken from the model's tool call.
    fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
}
469
470impl<T> AnyAgentTool for Erased<Arc<T>>
471where
472    T: AgentTool,
473{
474    fn name(&self) -> SharedString {
475        self.0.name()
476    }
477
478    fn description(&self, cx: &mut App) -> SharedString {
479        self.0.description(cx)
480    }
481
482    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
483        Ok(serde_json::to_value(self.0.input_schema(format))?)
484    }
485
486    fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
487        let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
488        match parsed_input {
489            Ok(input) => self.0.clone().run(input, cx),
490            Err(error) => Task::ready(Err(anyhow!(error))),
491        }
492    }
493}