request.rs

  1use std::io::{Cursor, Write};
  2use std::sync::Arc;
  3
  4use crate::role::Role;
  5use crate::{LanguageModelToolUse, LanguageModelToolUseId};
  6use anyhow::Result;
  7use base64::write::EncoderWriter;
  8use gpui::{
  9    App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
 10    point, px, size,
 11};
 12use image::codecs::png::PngEncoder;
 13use serde::{Deserialize, Serialize};
 14use util::ResultExt;
 15use zed_llm_client::CompletionMode;
 16
/// An image prepared for inclusion in a language model request.
///
/// `Debug` is implemented by hand for this type rather than derived.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    /// Pixel dimensions recorded for the image; read by `estimate_tokens`.
    size: Size<DevicePixels>,
}
 23
 24impl std::fmt::Debug for LanguageModelImage {
 25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 26        f.debug_struct("LanguageModelImage")
 27            .field("source", &format!("<{} bytes>", self.source.len()))
 28            .field("size", &self.size)
 29            .finish()
 30    }
 31}
 32
/// Anthropic wants uploaded images to be smaller than this in both dimensions.
///
/// Expressed in pixels as an `f32`. `from_image` compares it against the decoded
/// image's width and height, and uses it as the side length of the square bounding
/// box that oversized images are scaled down to fit.
const ANTHROPIC_SIZE_LIMT: f32 = 1568.;
 35
 36impl LanguageModelImage {
 37    pub fn empty() -> Self {
 38        Self {
 39            source: "".into(),
 40            size: size(DevicePixels(0), DevicePixels(0)),
 41        }
 42    }
 43
 44    pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
 45        cx.background_spawn(async move {
 46            let image_bytes = Cursor::new(data.bytes());
 47            let dynamic_image = match data.format() {
 48                ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
 49                    .and_then(image::DynamicImage::from_decoder),
 50                ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
 51                    .and_then(image::DynamicImage::from_decoder),
 52                ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
 53                    .and_then(image::DynamicImage::from_decoder),
 54                ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
 55                    .and_then(image::DynamicImage::from_decoder),
 56                _ => return None,
 57            }
 58            .log_err()?;
 59
 60            let width = dynamic_image.width();
 61            let height = dynamic_image.height();
 62            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));
 63
 64            let base64_image = {
 65                if image_size.width.0 > ANTHROPIC_SIZE_LIMT as i32
 66                    || image_size.height.0 > ANTHROPIC_SIZE_LIMT as i32
 67                {
 68                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
 69                        gpui::Bounds {
 70                            origin: point(px(0.0), px(0.0)),
 71                            size: size(px(ANTHROPIC_SIZE_LIMT), px(ANTHROPIC_SIZE_LIMT)),
 72                        },
 73                        image_size,
 74                    );
 75                    let resized_image = dynamic_image.resize(
 76                        new_bounds.size.width.0 as u32,
 77                        new_bounds.size.height.0 as u32,
 78                        image::imageops::FilterType::Triangle,
 79                    );
 80
 81                    encode_as_base64(data, resized_image)
 82                } else {
 83                    encode_as_base64(data, dynamic_image)
 84                }
 85            }
 86            .log_err()?;
 87
 88            // SAFETY: The base64 encoder should not produce non-UTF8.
 89            let source = unsafe { String::from_utf8_unchecked(base64_image) };
 90
 91            Some(LanguageModelImage {
 92                size: image_size,
 93                source: source.into(),
 94            })
 95        })
 96    }
 97
 98    pub fn estimate_tokens(&self) -> usize {
 99        let width = self.size.width.0.unsigned_abs() as usize;
100        let height = self.size.height.0.unsigned_abs() as usize;
101
102        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
103        // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
104        // so this method is more of a rough guess.
105        (width * height) / 750
106    }
107}
108
109fn encode_as_base64(data: Arc<Image>, image: image::DynamicImage) -> Result<Vec<u8>> {
110    let mut base64_image = Vec::new();
111    {
112        let mut base64_encoder = EncoderWriter::new(
113            Cursor::new(&mut base64_image),
114            &base64::engine::general_purpose::STANDARD,
115        );
116        if data.format() == ImageFormat::Png {
117            base64_encoder.write_all(data.bytes())?;
118        } else {
119            let mut png = Vec::new();
120            image.write_with_encoder(PngEncoder::new(&mut png))?;
121
122            base64_encoder.write_all(png.as_slice())?;
123        }
124    }
125    Ok(base64_image)
126}
127
/// The result of a tool invocation, carried inside `MessageContent::ToolResult`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    /// Identifier of the tool use this result belongs to.
    /// NOTE(review): presumably matches the id on the originating `LanguageModelToolUse`
    /// — confirm against callers.
    pub tool_use_id: LanguageModelToolUseId,
    /// Name of the tool associated with this result.
    pub tool_name: Arc<str>,
    /// NOTE(review): presumably `true` when the tool run failed — inferred from the
    /// field name; confirm against producers.
    pub is_error: bool,
    /// Textual content of the result. This is the text that
    /// `LanguageModelRequestMessage::string_contents` and `contents_empty` inspect.
    pub content: Arc<str>,
    /// Optional JSON value accompanying the result.
    /// NOTE(review): how consumers use it is not visible in this file.
    pub output: Option<serde_json::Value>,
}
136
/// One piece of content within a `LanguageModelRequestMessage`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    /// Plain text.
    Text(String),
    /// Thinking text together with an optional signature.
    /// NOTE(review): presumably model reasoning output plus a provider-issued
    /// signature — confirm against the provider integrations.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque bytes with no textual representation in this module.
    /// NOTE(review): presumably reasoning that the provider redacted — confirm.
    RedactedThinking(Vec<u8>),
    /// An image attachment.
    Image(LanguageModelImage),
    /// A tool invocation.
    ToolUse(LanguageModelToolUse),
    /// The result of a tool invocation.
    ToolResult(LanguageModelToolResult),
}
149
150impl From<String> for MessageContent {
151    fn from(value: String) -> Self {
152        MessageContent::Text(value)
153    }
154}
155
156impl From<&str> for MessageContent {
157    fn from(value: &str) -> Self {
158        MessageContent::Text(value.to_string())
159    }
160}
161
/// A single message in a `LanguageModelRequest`: a role plus an ordered list of
/// content parts.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    /// Role associated with this message.
    pub role: Role,
    /// Content parts of the message, in order.
    pub content: Vec<MessageContent>,
    /// NOTE(review): presumably marks this message as a prompt-caching point for
    /// providers that support it — not established by this file; confirm against the
    /// provider code.
    pub cache: bool,
}
168
169impl LanguageModelRequestMessage {
170    pub fn string_contents(&self) -> String {
171        let mut buffer = String::new();
172        for string in self.content.iter().filter_map(|content| match content {
173            MessageContent::Text(text) => Some(text.as_str()),
174            MessageContent::Thinking { text, .. } => Some(text.as_str()),
175            MessageContent::RedactedThinking(_) => None,
176            MessageContent::ToolResult(tool_result) => Some(tool_result.content.as_ref()),
177            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
178        }) {
179            buffer.push_str(string);
180        }
181
182        buffer
183    }
184
185    pub fn contents_empty(&self) -> bool {
186        self.content.iter().all(|content| match content {
187            MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
188            MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
189            MessageContent::ToolResult(tool_result) => {
190                tool_result.content.chars().all(|c| c.is_whitespace())
191            }
192            MessageContent::RedactedThinking(_)
193            | MessageContent::ToolUse(_)
194            | MessageContent::Image(_) => false,
195        })
196    }
197}
198
/// Description of a tool listed in `LanguageModelRequest::tools`.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    /// Name of the tool.
    pub name: String,
    /// Free-form description of the tool.
    pub description: String,
    /// JSON value describing the tool's input.
    /// NOTE(review): presumably a JSON Schema document — confirm against producers.
    pub input_schema: serde_json::Value,
}
205
/// Tool-choice setting used by `LanguageModelRequest::tool_choice`.
///
/// NOTE(review): the variant meanings below are inferred from their names and are not
/// established by this file — confirm against the provider integrations.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
    /// Presumably: the model decides whether to call a tool.
    Auto,
    /// Presumably: the model must call some tool.
    Any,
    /// Presumably: the model must not call any tool. This is a variant of this enum and
    /// is distinct from `Option::None` on `LanguageModelRequest::tool_choice`.
    None,
}
212
/// A request to a language model.
///
/// Derives `Default`, so every optional field defaults to `None` and every list to empty.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    /// Optional thread identifier.
    /// NOTE(review): what it correlates with is not visible in this file.
    pub thread_id: Option<String>,
    /// Optional prompt identifier.
    /// NOTE(review): what it correlates with is not visible in this file.
    pub prompt_id: Option<String>,
    /// Optional completion mode, as defined by `zed_llm_client::CompletionMode`.
    pub mode: Option<CompletionMode>,
    /// Messages that make up the request, in order.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Tools listed in the request.
    pub tools: Vec<LanguageModelRequestTool>,
    /// Optional tool-choice setting.
    pub tool_choice: Option<LanguageModelToolChoice>,
    /// NOTE(review): presumably stop sequences that end generation — confirm.
    pub stop: Vec<String>,
    /// NOTE(review): presumably the sampling temperature, with `None` deferring to a
    /// default chosen elsewhere — confirm.
    pub temperature: Option<f32>,
}
224
/// A response message in which both the role and the textual content are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    /// Role of the message, if present.
    pub role: Option<Role>,
    /// Text content of the message, if present.
    pub content: Option<String>,
}