request.rs

  1use std::io::{Cursor, Write};
  2use std::sync::Arc;
  3
  4use crate::role::Role;
  5use crate::{LanguageModelToolUse, LanguageModelToolUseId};
  6use anyhow::Result;
  7use base64::write::EncoderWriter;
  8use gpui::{
  9    App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
 10    point, px, size,
 11};
 12use image::codecs::png::PngEncoder;
 13use serde::{Deserialize, Serialize};
 14use util::ResultExt;
 15use zed_llm_client::CompletionMode;
 16
 17#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 18pub struct LanguageModelImage {
 19    /// A base64-encoded PNG image.
 20    pub source: SharedString,
 21    size: Size<DevicePixels>,
 22}
 23
 24impl LanguageModelImage {
 25    pub fn len(&self) -> usize {
 26        self.source.len()
 27    }
 28
 29    pub fn is_empty(&self) -> bool {
 30        self.source.is_empty()
 31    }
 32}
 33
 34impl std::fmt::Debug for LanguageModelImage {
 35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 36        f.debug_struct("LanguageModelImage")
 37            .field("source", &format!("<{} bytes>", self.source.len()))
 38            .field("size", &self.size)
 39            .finish()
 40    }
 41}
 42
 43/// Anthropic wants uploaded images to be smaller than this in both dimensions.
 44const ANTHROPIC_SIZE_LIMT: f32 = 1568.;
 45
 46impl LanguageModelImage {
 47    pub fn empty() -> Self {
 48        Self {
 49            source: "".into(),
 50            size: size(DevicePixels(0), DevicePixels(0)),
 51        }
 52    }
 53
 54    pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
 55        cx.background_spawn(async move {
 56            let image_bytes = Cursor::new(data.bytes());
 57            let dynamic_image = match data.format() {
 58                ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
 59                    .and_then(image::DynamicImage::from_decoder),
 60                ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
 61                    .and_then(image::DynamicImage::from_decoder),
 62                ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
 63                    .and_then(image::DynamicImage::from_decoder),
 64                ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
 65                    .and_then(image::DynamicImage::from_decoder),
 66                _ => return None,
 67            }
 68            .log_err()?;
 69
 70            let width = dynamic_image.width();
 71            let height = dynamic_image.height();
 72            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));
 73
 74            let base64_image = {
 75                if image_size.width.0 > ANTHROPIC_SIZE_LIMT as i32
 76                    || image_size.height.0 > ANTHROPIC_SIZE_LIMT as i32
 77                {
 78                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
 79                        gpui::Bounds {
 80                            origin: point(px(0.0), px(0.0)),
 81                            size: size(px(ANTHROPIC_SIZE_LIMT), px(ANTHROPIC_SIZE_LIMT)),
 82                        },
 83                        image_size,
 84                    );
 85                    let resized_image = dynamic_image.resize(
 86                        new_bounds.size.width.0 as u32,
 87                        new_bounds.size.height.0 as u32,
 88                        image::imageops::FilterType::Triangle,
 89                    );
 90
 91                    encode_as_base64(data, resized_image)
 92                } else {
 93                    encode_as_base64(data, dynamic_image)
 94                }
 95            }
 96            .log_err()?;
 97
 98            // SAFETY: The base64 encoder should not produce non-UTF8.
 99            let source = unsafe { String::from_utf8_unchecked(base64_image) };
100
101            Some(LanguageModelImage {
102                size: image_size,
103                source: source.into(),
104            })
105        })
106    }
107
108    pub fn estimate_tokens(&self) -> usize {
109        let width = self.size.width.0.unsigned_abs() as usize;
110        let height = self.size.height.0.unsigned_abs() as usize;
111
112        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
113        // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
114        // so this method is more of a rough guess.
115        (width * height) / 750
116    }
117
118    pub fn to_base64_url(&self) -> String {
119        format!("data:image/png;base64,{}", self.source)
120    }
121}
122
123fn encode_as_base64(data: Arc<Image>, image: image::DynamicImage) -> Result<Vec<u8>> {
124    let mut base64_image = Vec::new();
125    {
126        let mut base64_encoder = EncoderWriter::new(
127            Cursor::new(&mut base64_image),
128            &base64::engine::general_purpose::STANDARD,
129        );
130        if data.format() == ImageFormat::Png {
131            base64_encoder.write_all(data.bytes())?;
132        } else {
133            let mut png = Vec::new();
134            image.write_with_encoder(PngEncoder::new(&mut png))?;
135
136            base64_encoder.write_all(png.as_slice())?;
137        }
138    }
139    Ok(base64_image)
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
143pub struct LanguageModelToolResult {
144    pub tool_use_id: LanguageModelToolUseId,
145    pub tool_name: Arc<str>,
146    pub is_error: bool,
147    pub content: LanguageModelToolResultContent,
148    pub output: Option<serde_json::Value>,
149}
150
151#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq, Hash)]
152#[serde(untagged)]
153pub enum LanguageModelToolResultContent {
154    Text(Arc<str>),
155    Image(LanguageModelImage),
156    WrappedText(WrappedTextContent),
157}
158
159#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq, Hash)]
160pub struct WrappedTextContent {
161    #[serde(rename = "type")]
162    pub content_type: String,
163    pub text: Arc<str>,
164}
165
166impl LanguageModelToolResultContent {
167    pub fn to_str(&self) -> Option<&str> {
168        match self {
169            Self::Text(text) | Self::WrappedText(WrappedTextContent { text, .. }) => Some(&text),
170            Self::Image(_) => None,
171        }
172    }
173
174    pub fn is_empty(&self) -> bool {
175        match self {
176            Self::Text(text) | Self::WrappedText(WrappedTextContent { text, .. }) => {
177                text.chars().all(|c| c.is_whitespace())
178            }
179            Self::Image(_) => false,
180        }
181    }
182}
183
184impl From<&str> for LanguageModelToolResultContent {
185    fn from(value: &str) -> Self {
186        Self::Text(Arc::from(value))
187    }
188}
189
190impl From<String> for LanguageModelToolResultContent {
191    fn from(value: String) -> Self {
192        Self::Text(Arc::from(value))
193    }
194}
195
196#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
197pub enum MessageContent {
198    Text(String),
199    Thinking {
200        text: String,
201        signature: Option<String>,
202    },
203    RedactedThinking(Vec<u8>),
204    Image(LanguageModelImage),
205    ToolUse(LanguageModelToolUse),
206    ToolResult(LanguageModelToolResult),
207}
208
209impl MessageContent {
210    pub fn to_str(&self) -> Option<&str> {
211        match self {
212            MessageContent::Text(text) => Some(text.as_str()),
213            MessageContent::Thinking { text, .. } => Some(text.as_str()),
214            MessageContent::RedactedThinking(_) => None,
215            MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
216            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
217        }
218    }
219
220    pub fn is_empty(&self) -> bool {
221        match self {
222            MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
223            MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
224            MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
225            MessageContent::RedactedThinking(_)
226            | MessageContent::ToolUse(_)
227            | MessageContent::Image(_) => false,
228        }
229    }
230}
231
232impl From<String> for MessageContent {
233    fn from(value: String) -> Self {
234        MessageContent::Text(value)
235    }
236}
237
238impl From<&str> for MessageContent {
239    fn from(value: &str) -> Self {
240        MessageContent::Text(value.to_string())
241    }
242}
243
244#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
245pub struct LanguageModelRequestMessage {
246    pub role: Role,
247    pub content: Vec<MessageContent>,
248    pub cache: bool,
249}
250
251impl LanguageModelRequestMessage {
252    pub fn string_contents(&self) -> String {
253        let mut buffer = String::new();
254        for string in self.content.iter().filter_map(|content| content.to_str()) {
255            buffer.push_str(string);
256        }
257
258        buffer
259    }
260
261    pub fn contents_empty(&self) -> bool {
262        self.content.iter().all(|content| content.is_empty())
263    }
264}
265
266#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
267pub struct LanguageModelRequestTool {
268    pub name: String,
269    pub description: String,
270    pub input_schema: serde_json::Value,
271}
272
273#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
274pub enum LanguageModelToolChoice {
275    Auto,
276    Any,
277    None,
278}
279
280#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
281pub struct LanguageModelRequest {
282    pub thread_id: Option<String>,
283    pub prompt_id: Option<String>,
284    pub mode: Option<CompletionMode>,
285    pub messages: Vec<LanguageModelRequestMessage>,
286    pub tools: Vec<LanguageModelRequestTool>,
287    pub tool_choice: Option<LanguageModelToolChoice>,
288    pub stop: Vec<String>,
289    pub temperature: Option<f32>,
290}
291
292#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
293pub struct LanguageModelResponseMessage {
294    pub role: Option<Role>,
295    pub content: Option<String>,
296}