request.rs

  1use std::io::{Cursor, Write};
  2use std::sync::Arc;
  3
  4use crate::role::Role;
  5use crate::{LanguageModelToolUse, LanguageModelToolUseId};
  6use anyhow::Result;
  7use base64::write::EncoderWriter;
  8use gpui::{
  9    App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
 10    point, px, size,
 11};
 12use image::codecs::png::PngEncoder;
 13use serde::{Deserialize, Serialize};
 14use util::ResultExt;
 15use zed_llm_client::{CompletionIntent, CompletionMode};
 16
/// An image attached to a language model request: base64-encoded PNG data plus the image's
/// pixel dimensions.
///
/// `Debug` is implemented by hand so that formatting does not dump the whole payload.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    /// Width and height of the image, in device pixels.
    pub size: Size<DevicePixels>,
}
 23
 24impl LanguageModelImage {
 25    pub fn len(&self) -> usize {
 26        self.source.len()
 27    }
 28
 29    pub fn is_empty(&self) -> bool {
 30        self.source.is_empty()
 31    }
 32
 33    // Parse Self from a JSON object with case-insensitive field names
 34    pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
 35        let mut source = None;
 36        let mut size_obj = None;
 37
 38        // Find source and size fields (case-insensitive)
 39        for (k, v) in obj.iter() {
 40            match k.to_lowercase().as_str() {
 41                "source" => source = v.as_str(),
 42                "size" => size_obj = v.as_object(),
 43                _ => {}
 44            }
 45        }
 46
 47        let source = source?;
 48        let size_obj = size_obj?;
 49
 50        let mut width = None;
 51        let mut height = None;
 52
 53        // Find width and height in size object (case-insensitive)
 54        for (k, v) in size_obj.iter() {
 55            match k.to_lowercase().as_str() {
 56                "width" => width = v.as_i64().map(|w| w as i32),
 57                "height" => height = v.as_i64().map(|h| h as i32),
 58                _ => {}
 59            }
 60        }
 61
 62        Some(Self {
 63            size: size(DevicePixels(width?), DevicePixels(height?)),
 64            source: SharedString::from(source.to_string()),
 65        })
 66    }
 67}
 68
 69impl std::fmt::Debug for LanguageModelImage {
 70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 71        f.debug_struct("LanguageModelImage")
 72            .field("source", &format!("<{} bytes>", self.source.len()))
 73            .field("size", &self.size)
 74            .finish()
 75    }
 76}
 77
 78/// Anthropic wants uploaded images to be smaller than this in both dimensions.
 79const ANTHROPIC_SIZE_LIMT: f32 = 1568.;
 80
 81impl LanguageModelImage {
 82    pub fn empty() -> Self {
 83        Self {
 84            source: "".into(),
 85            size: size(DevicePixels(0), DevicePixels(0)),
 86        }
 87    }
 88
 89    pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
 90        cx.background_spawn(async move {
 91            let image_bytes = Cursor::new(data.bytes());
 92            let dynamic_image = match data.format() {
 93                ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
 94                    .and_then(image::DynamicImage::from_decoder),
 95                ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
 96                    .and_then(image::DynamicImage::from_decoder),
 97                ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
 98                    .and_then(image::DynamicImage::from_decoder),
 99                ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
100                    .and_then(image::DynamicImage::from_decoder),
101                _ => return None,
102            }
103            .log_err()?;
104
105            let width = dynamic_image.width();
106            let height = dynamic_image.height();
107            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));
108
109            let base64_image = {
110                if image_size.width.0 > ANTHROPIC_SIZE_LIMT as i32
111                    || image_size.height.0 > ANTHROPIC_SIZE_LIMT as i32
112                {
113                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
114                        gpui::Bounds {
115                            origin: point(px(0.0), px(0.0)),
116                            size: size(px(ANTHROPIC_SIZE_LIMT), px(ANTHROPIC_SIZE_LIMT)),
117                        },
118                        image_size,
119                    );
120                    let resized_image = dynamic_image.resize(
121                        new_bounds.size.width.0 as u32,
122                        new_bounds.size.height.0 as u32,
123                        image::imageops::FilterType::Triangle,
124                    );
125
126                    encode_as_base64(data, resized_image)
127                } else {
128                    encode_as_base64(data, dynamic_image)
129                }
130            }
131            .log_err()?;
132
133            // SAFETY: The base64 encoder should not produce non-UTF8.
134            let source = unsafe { String::from_utf8_unchecked(base64_image) };
135
136            Some(LanguageModelImage {
137                size: image_size,
138                source: source.into(),
139            })
140        })
141    }
142
143    pub fn estimate_tokens(&self) -> usize {
144        let width = self.size.width.0.unsigned_abs() as usize;
145        let height = self.size.height.0.unsigned_abs() as usize;
146
147        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
148        // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
149        // so this method is more of a rough guess.
150        (width * height) / 750
151    }
152
153    pub fn to_base64_url(&self) -> String {
154        format!("data:image/png;base64,{}", self.source)
155    }
156}
157
158fn encode_as_base64(data: Arc<Image>, image: image::DynamicImage) -> Result<Vec<u8>> {
159    let mut base64_image = Vec::new();
160    {
161        let mut base64_encoder = EncoderWriter::new(
162            Cursor::new(&mut base64_image),
163            &base64::engine::general_purpose::STANDARD,
164        );
165        if data.format() == ImageFormat::Png {
166            base64_encoder.write_all(data.bytes())?;
167        } else {
168            let mut png = Vec::new();
169            image.write_with_encoder(PngEncoder::new(&mut png))?;
170
171            base64_encoder.write_all(png.as_slice())?;
172        }
173    }
174    Ok(base64_image)
175}
176
/// The result of a tool invocation, carried by `MessageContent::ToolResult`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    /// Identifies the tool use this result corresponds to.
    pub tool_use_id: LanguageModelToolUseId,
    /// Name of the tool that produced this result.
    pub tool_name: Arc<str>,
    /// Whether the tool reported an error; `content` then presumably describes it.
    pub is_error: bool,
    /// The result content: text or an image.
    pub content: LanguageModelToolResultContent,
    /// Optional structured output alongside `content`; its shape is tool-specific and not
    /// defined in this file.
    pub output: Option<serde_json::Value>,
}
185
/// Content returned by a tool: either text or an image.
///
/// Serialization uses serde's default externally tagged form (`{"Text": ...}` or
/// `{"Image": ...}`). Deserialization is hand-written and also accepts plain strings,
/// `{"type": "text", "text": ...}` objects and bare image objects.
#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
pub enum LanguageModelToolResultContent {
    /// Textual output.
    Text(Arc<str>),
    /// Image output.
    Image(LanguageModelImage),
}
191
192impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
193    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
194    where
195        D: serde::Deserializer<'de>,
196    {
197        use serde::de::Error;
198
199        let value = serde_json::Value::deserialize(deserializer)?;
200
201        // Models can provide these responses in several styles. Try each in order.
202
203        // 1. Try as plain string
204        if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
205            return Ok(Self::Text(Arc::from(text)));
206        }
207
208        // 2. Try as object
209        if let Some(obj) = value.as_object() {
210            // get a JSON field case-insensitively
211            fn get_field<'a>(
212                obj: &'a serde_json::Map<String, serde_json::Value>,
213                field: &str,
214            ) -> Option<&'a serde_json::Value> {
215                obj.iter()
216                    .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
217                    .map(|(_, v)| v)
218            }
219
220            // Accept wrapped text format: { "type": "text", "text": "..." }
221            if let (Some(type_value), Some(text_value)) =
222                (get_field(&obj, "type"), get_field(&obj, "text"))
223            {
224                if let Some(type_str) = type_value.as_str() {
225                    if type_str.to_lowercase() == "text" {
226                        if let Some(text) = text_value.as_str() {
227                            return Ok(Self::Text(Arc::from(text)));
228                        }
229                    }
230                }
231            }
232
233            // Check for wrapped Text variant: { "text": "..." }
234            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") {
235                if obj.len() == 1 {
236                    // Only one field, and it's "text" (case-insensitive)
237                    if let Some(text) = value.as_str() {
238                        return Ok(Self::Text(Arc::from(text)));
239                    }
240                }
241            }
242
243            // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
244            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") {
245                if obj.len() == 1 {
246                    // Only one field, and it's "image" (case-insensitive)
247                    // Try to parse the nested image object
248                    if let Some(image_obj) = value.as_object() {
249                        if let Some(image) = LanguageModelImage::from_json(image_obj) {
250                            return Ok(Self::Image(image));
251                        }
252                    }
253                }
254            }
255
256            // Try as direct Image (object with "source" and "size" fields)
257            if let Some(image) = LanguageModelImage::from_json(&obj) {
258                return Ok(Self::Image(image));
259            }
260        }
261
262        // If none of the variants match, return an error with the problematic JSON
263        Err(D::Error::custom(format!(
264            "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
265             an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
266            serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
267        )))
268    }
269}
270
271impl LanguageModelToolResultContent {
272    pub fn to_str(&self) -> Option<&str> {
273        match self {
274            Self::Text(text) => Some(&text),
275            Self::Image(_) => None,
276        }
277    }
278
279    pub fn is_empty(&self) -> bool {
280        match self {
281            Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
282            Self::Image(_) => false,
283        }
284    }
285}
286
287impl From<&str> for LanguageModelToolResultContent {
288    fn from(value: &str) -> Self {
289        Self::Text(Arc::from(value))
290    }
291}
292
293impl From<String> for LanguageModelToolResultContent {
294    fn from(value: String) -> Self {
295        Self::Text(Arc::from(value))
296    }
297}
298
/// A single piece of content inside a [`LanguageModelRequestMessage`].
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    /// Plain text.
    Text(String),
    /// Thinking text, with an optional signature string.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Redacted thinking. Its payload is never treated as text (`to_str` returns `None`).
    RedactedThinking(String),
    /// An attached image.
    Image(LanguageModelImage),
    /// A tool invocation.
    ToolUse(LanguageModelToolUse),
    /// The result of a tool invocation.
    ToolResult(LanguageModelToolResult),
}
311
312impl MessageContent {
313    pub fn to_str(&self) -> Option<&str> {
314        match self {
315            MessageContent::Text(text) => Some(text.as_str()),
316            MessageContent::Thinking { text, .. } => Some(text.as_str()),
317            MessageContent::RedactedThinking(_) => None,
318            MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
319            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
320        }
321    }
322
323    pub fn is_empty(&self) -> bool {
324        match self {
325            MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
326            MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
327            MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
328            MessageContent::RedactedThinking(_)
329            | MessageContent::ToolUse(_)
330            | MessageContent::Image(_) => false,
331        }
332    }
333}
334
335impl From<String> for MessageContent {
336    fn from(value: String) -> Self {
337        MessageContent::Text(value)
338    }
339}
340
341impl From<&str> for MessageContent {
342    fn from(value: &str) -> Self {
343        MessageContent::Text(value.to_string())
344    }
345}
346
/// One message of a [`LanguageModelRequest`]: an author role plus ordered content items.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    /// Who authored the message.
    pub role: Role,
    /// The message's content items, in order.
    pub content: Vec<MessageContent>,
    /// Presumably marks this message as a prompt-cache point; how it is honored is up to the
    /// provider code (not visible in this file).
    pub cache: bool,
}
353
354impl LanguageModelRequestMessage {
355    pub fn string_contents(&self) -> String {
356        let mut buffer = String::new();
357        for string in self.content.iter().filter_map(|content| content.to_str()) {
358            buffer.push_str(string);
359        }
360
361        buffer
362    }
363
364    pub fn contents_empty(&self) -> bool {
365        self.content.iter().all(|content| content.is_empty())
366    }
367}
368
/// A tool offered to the model as part of a [`LanguageModelRequest`].
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    /// Name of the tool.
    pub name: String,
    /// Description of the tool.
    pub description: String,
    /// JSON describing the tool's input (presumably a JSON Schema; confirm with providers).
    pub input_schema: serde_json::Value,
}
375
/// Constraint on how the model may choose among the request's tools.
// NOTE(review): the variant meanings below follow common provider conventions; the mapping
// is done by provider code not visible here — confirm there.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
    /// Presumably: the model decides whether to call a tool.
    Auto,
    /// Presumably: the model must call some tool.
    Any,
    /// Presumably: the model must not call any tool.
    None,
}
382
/// A complete request to a language model.
///
/// `Default` yields a request with no messages, no tools, no stop sequences, every optional
/// field unset, and `thinking_allowed == false`.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    /// Identifier of the thread this request belongs to, if any.
    pub thread_id: Option<String>,
    /// Identifier of the prompt, if any.
    pub prompt_id: Option<String>,
    /// Intent of the completion, if known.
    pub intent: Option<CompletionIntent>,
    /// Requested completion mode, if any.
    pub mode: Option<CompletionMode>,
    /// The conversation messages, in order.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Tools made available to the model.
    pub tools: Vec<LanguageModelRequestTool>,
    /// Tool selection constraint; `None` presumably leaves it to the provider default.
    pub tool_choice: Option<LanguageModelToolChoice>,
    /// Stop sequences.
    pub stop: Vec<String>,
    /// Sampling temperature; `None` presumably leaves it to the provider default.
    pub temperature: Option<f32>,
    /// Whether the model may produce thinking output for this request.
    pub thinking_allowed: bool,
}
396
/// A message returned by a model; the role and the content may each be absent.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    /// Author role of the message, if provided.
    pub role: Option<Role>,
    /// Text content of the message, if provided.
    pub content: Option<String>,
}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json`, panicking if it is rejected.
    fn parse(json: &str) -> LanguageModelToolResultContent {
        serde_json::from_str(json).unwrap()
    }

    /// Asserts that `json` deserializes to text content equal to `expected`.
    fn assert_text(json: &str, expected: &str) {
        assert_eq!(
            parse(json),
            LanguageModelToolResultContent::Text(expected.into())
        );
    }

    /// Asserts that `json` deserializes to an image with the given source and dimensions.
    fn assert_image(json: &str, source: &str, width: i32, height: i32) {
        let LanguageModelToolResultContent::Image(image) = parse(json) else {
            panic!("Expected Image variant");
        };
        assert_eq!(image.source.as_ref(), source);
        assert_eq!(image.size.width.0, width);
        assert_eq!(image.size.height.0, height);
    }

    /// Asserts that `json` is rejected by the deserializer.
    fn assert_rejected(json: &str) {
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_language_model_tool_result_content_deserialization() {
        // Plain and wrapped text, including case-insensitive keys and type values.
        assert_text(r#""This is plain text""#, "This is plain text");
        assert_text(
            r#"{"type": "text", "text": "This is wrapped text"}"#,
            "This is wrapped text",
        );
        assert_text(
            r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#,
            "Case insensitive",
        );
        assert_text(r#"{"Text": "Wrapped variant"}"#, "Wrapped variant");
        assert_text(r#"{"text": "Lowercase wrapped"}"#, "Lowercase wrapped");

        // Direct image object.
        assert_image(
            r#"{
                "source": "base64encodedimagedata",
                "size": { "width": 100, "height": 200 }
            }"#,
            "base64encodedimagedata",
            100,
            200,
        );

        // Wrapped image variant.
        assert_image(
            r#"{
                "Image": {
                    "source": "wrappedimagedata",
                    "size": { "width": 50, "height": 75 }
                }
            }"#,
            "wrappedimagedata",
            50,
            75,
        );

        // Wrapped image variant with case-insensitive keys.
        assert_image(
            r#"{
                "image": {
                    "Source": "caseinsensitive",
                    "SIZE": { "width": 30, "height": 40 }
                }
            }"#,
            "caseinsensitive",
            30,
            40,
        );

        // A typed block whose type is not "text" is rejected, as is an unknown structure.
        assert_rejected(r#"{"type": "blahblah", "text": "This should fail"}"#);
        assert_rejected(r#"{"invalid": "structure"}"#);

        // Edge case: the empty string.
        assert_text(r#""""#, "");

        // Extra fields next to a typed text block are ignored.
        assert_text(
            r#"{"type": "text", "text": "Hello", "extra": "field"}"#,
            "Hello",
        );

        // Direct image with case-insensitive keys.
        assert_image(
            r#"{
                "SOURCE": "directimage",
                "Size": { "width": 200, "height": 300 }
            }"#,
            "directimage",
            200,
            300,
        );

        // Extra fields prevent the wrapped-variant interpretation.
        assert_rejected(r#"{"Text": "not wrapped", "extra": "field"}"#);

        // Wrapped text with an uppercase key.
        assert_text(r#"{"TEXT": "Uppercase variant"}"#, "Uppercase variant");

        // Numbers, null and arrays are rejected.
        assert_rejected("123");
        assert_rejected("null");
        assert_rejected("[1, 2, 3]");
    }
}