// request.rs

  1use std::io::{Cursor, Write};
  2use std::sync::Arc;
  3
  4use anyhow::Result;
  5use base64::write::EncoderWriter;
  6use cloud_llm_client::{CompletionIntent, CompletionMode};
  7use gpui::{
  8    App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
  9    point, px, size,
 10};
 11use image::codecs::png::PngEncoder;
 12use serde::{Deserialize, Serialize};
 13use util::ResultExt;
 14
 15use crate::role::Role;
 16use crate::{LanguageModelToolUse, LanguageModelToolUseId};
 17
/// An image attachment for a language-model request, stored as a
/// base64-encoded PNG together with its pixel dimensions.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    /// Pixel dimensions of the image.
    pub size: Size<DevicePixels>,
}
 24
 25impl LanguageModelImage {
 26    pub fn len(&self) -> usize {
 27        self.source.len()
 28    }
 29
 30    pub fn is_empty(&self) -> bool {
 31        self.source.is_empty()
 32    }
 33
 34    // Parse Self from a JSON object with case-insensitive field names
 35    pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
 36        let mut source = None;
 37        let mut size_obj = None;
 38
 39        // Find source and size fields (case-insensitive)
 40        for (k, v) in obj.iter() {
 41            match k.to_lowercase().as_str() {
 42                "source" => source = v.as_str(),
 43                "size" => size_obj = v.as_object(),
 44                _ => {}
 45            }
 46        }
 47
 48        let source = source?;
 49        let size_obj = size_obj?;
 50
 51        let mut width = None;
 52        let mut height = None;
 53
 54        // Find width and height in size object (case-insensitive)
 55        for (k, v) in size_obj.iter() {
 56            match k.to_lowercase().as_str() {
 57                "width" => width = v.as_i64().map(|w| w as i32),
 58                "height" => height = v.as_i64().map(|h| h as i32),
 59                _ => {}
 60            }
 61        }
 62
 63        Some(Self {
 64            size: size(DevicePixels(width?), DevicePixels(height?)),
 65            source: SharedString::from(source.to_string()),
 66        })
 67    }
 68}
 69
 70impl std::fmt::Debug for LanguageModelImage {
 71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 72        f.debug_struct("LanguageModelImage")
 73            .field("source", &format!("<{} bytes>", self.source.len()))
 74            .field("size", &self.size)
 75            .finish()
 76    }
 77}
 78
/// Anthropic wants uploaded images to be smaller than this in both dimensions.
/// Images exceeding this on either axis are scaled down before upload.
/// NOTE(review): the name has a typo ("LIMT" → "LIMIT"); renaming requires
/// updating its uses in `from_image` as well, so it is left as-is here.
const ANTHROPIC_SIZE_LIMT: f32 = 1568.;
 81
impl LanguageModelImage {
    /// An image with no data and zero size.
    pub fn empty() -> Self {
        Self {
            source: "".into(),
            size: size(DevicePixels(0), DevicePixels(0)),
        }
    }

    /// Decode `data` on a background thread and repackage it as a
    /// base64-encoded PNG, scaling it down when either dimension exceeds
    /// `ANTHROPIC_SIZE_LIMT`.
    ///
    /// Resolves to `None` when the image format is unsupported or decoding
    /// fails (the error is logged via `log_err`).
    pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
        cx.background_spawn(async move {
            let image_bytes = Cursor::new(data.bytes());
            // Only PNG, JPEG, WebP and GIF inputs are supported; any other
            // format bails out with `None`.
            let dynamic_image = match data.format() {
                ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                _ => return None,
            }
            .log_err()?;

            let width = dynamic_image.width();
            let height = dynamic_image.height();
            // NOTE(review): this records the ORIGINAL dimensions. When the
            // image is resized below, the returned `size` (and therefore
            // `estimate_tokens`) still reflects the pre-resize dimensions —
            // confirm this is intentional.
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

            let base64_image = {
                if image_size.width.0 > ANTHROPIC_SIZE_LIMT as i32
                    || image_size.height.0 > ANTHROPIC_SIZE_LIMT as i32
                {
                    // Scale down to fit inside the limit box while
                    // preserving aspect ratio.
                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
                        gpui::Bounds {
                            origin: point(px(0.0), px(0.0)),
                            size: size(px(ANTHROPIC_SIZE_LIMT), px(ANTHROPIC_SIZE_LIMT)),
                        },
                        image_size,
                    );
                    let resized_image = dynamic_image.resize(
                        new_bounds.size.width.0 as u32,
                        new_bounds.size.height.0 as u32,
                        image::imageops::FilterType::Triangle,
                    );

                    encode_as_base64(data, resized_image)
                } else {
                    encode_as_base64(data, dynamic_image)
                }
            }
            .log_err()?;

            // SAFETY: The base64 encoder should not produce non-UTF8.
            // Standard base64 output is a subset of ASCII.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    /// Rough token-count estimate for this image.
    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
        // so this method is more of a rough guess.
        (width * height) / 750
    }

    /// Render this image as a `data:` URL suitable for APIs that take
    /// inline base64 image URLs.
    pub fn to_base64_url(&self) -> String {
        format!("data:image/png;base64,{}", self.source)
    }
}
158
159fn encode_as_base64(data: Arc<Image>, image: image::DynamicImage) -> Result<Vec<u8>> {
160    let mut base64_image = Vec::new();
161    {
162        let mut base64_encoder = EncoderWriter::new(
163            Cursor::new(&mut base64_image),
164            &base64::engine::general_purpose::STANDARD,
165        );
166        if data.format() == ImageFormat::Png {
167            base64_encoder.write_all(data.bytes())?;
168        } else {
169            let mut png = Vec::new();
170            image.write_with_encoder(PngEncoder::new(&mut png))?;
171
172            base64_encoder.write_all(png.as_slice())?;
173        }
174    }
175    Ok(base64_image)
176}
177
/// The result of running a tool, reported back to the model.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    /// Identifier of the tool-use request this result answers.
    pub tool_use_id: LanguageModelToolUseId,
    /// Name of the tool that was invoked.
    pub tool_name: Arc<str>,
    /// Whether the tool run ended in an error.
    pub is_error: bool,
    /// The payload handed back to the model (text or image).
    pub content: LanguageModelToolResultContent,
    /// Optional structured output, passed through verbatim as JSON.
    pub output: Option<serde_json::Value>,
}
186
/// Payload of a tool result: either plain text or an image.
///
/// `Deserialize` is implemented by hand because models emit this value in
/// several shapes (bare string, tagged object, wrapped variants).
#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
pub enum LanguageModelToolResultContent {
    Text(Arc<str>),
    Image(LanguageModelImage),
}
192
193impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
194    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
195    where
196        D: serde::Deserializer<'de>,
197    {
198        use serde::de::Error;
199
200        let value = serde_json::Value::deserialize(deserializer)?;
201
202        // Models can provide these responses in several styles. Try each in order.
203
204        // 1. Try as plain string
205        if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
206            return Ok(Self::Text(Arc::from(text)));
207        }
208
209        // 2. Try as object
210        if let Some(obj) = value.as_object() {
211            // get a JSON field case-insensitively
212            fn get_field<'a>(
213                obj: &'a serde_json::Map<String, serde_json::Value>,
214                field: &str,
215            ) -> Option<&'a serde_json::Value> {
216                obj.iter()
217                    .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
218                    .map(|(_, v)| v)
219            }
220
221            // Accept wrapped text format: { "type": "text", "text": "..." }
222            if let (Some(type_value), Some(text_value)) =
223                (get_field(obj, "type"), get_field(obj, "text"))
224            {
225                if let Some(type_str) = type_value.as_str() {
226                    if type_str.to_lowercase() == "text" {
227                        if let Some(text) = text_value.as_str() {
228                            return Ok(Self::Text(Arc::from(text)));
229                        }
230                    }
231                }
232            }
233
234            // Check for wrapped Text variant: { "text": "..." }
235            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") {
236                if obj.len() == 1 {
237                    // Only one field, and it's "text" (case-insensitive)
238                    if let Some(text) = value.as_str() {
239                        return Ok(Self::Text(Arc::from(text)));
240                    }
241                }
242            }
243
244            // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
245            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") {
246                if obj.len() == 1 {
247                    // Only one field, and it's "image" (case-insensitive)
248                    // Try to parse the nested image object
249                    if let Some(image_obj) = value.as_object() {
250                        if let Some(image) = LanguageModelImage::from_json(image_obj) {
251                            return Ok(Self::Image(image));
252                        }
253                    }
254                }
255            }
256
257            // Try as direct Image (object with "source" and "size" fields)
258            if let Some(image) = LanguageModelImage::from_json(obj) {
259                return Ok(Self::Image(image));
260            }
261        }
262
263        // If none of the variants match, return an error with the problematic JSON
264        Err(D::Error::custom(format!(
265            "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
266             an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
267            serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
268        )))
269    }
270}
271
272impl LanguageModelToolResultContent {
273    pub fn to_str(&self) -> Option<&str> {
274        match self {
275            Self::Text(text) => Some(text),
276            Self::Image(_) => None,
277        }
278    }
279
280    pub fn is_empty(&self) -> bool {
281        match self {
282            Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
283            Self::Image(_) => false,
284        }
285    }
286}
287
288impl From<&str> for LanguageModelToolResultContent {
289    fn from(value: &str) -> Self {
290        Self::Text(Arc::from(value))
291    }
292}
293
294impl From<String> for LanguageModelToolResultContent {
295    fn from(value: String) -> Self {
296        Self::Text(Arc::from(value))
297    }
298}
299
300impl From<LanguageModelImage> for LanguageModelToolResultContent {
301    fn from(image: LanguageModelImage) -> Self {
302        Self::Image(image)
303    }
304}
305
/// One piece of a request message's content.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    /// Plain text.
    Text(String),
    /// Model "thinking" text, optionally with a provider signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Thinking content that has been redacted by the provider.
    RedactedThinking(String),
    /// An inline image attachment.
    Image(LanguageModelImage),
    /// A tool invocation requested by the model.
    ToolUse(LanguageModelToolUse),
    /// The result of a tool invocation.
    ToolResult(LanguageModelToolResult),
}
318
319impl MessageContent {
320    pub fn to_str(&self) -> Option<&str> {
321        match self {
322            MessageContent::Text(text) => Some(text.as_str()),
323            MessageContent::Thinking { text, .. } => Some(text.as_str()),
324            MessageContent::RedactedThinking(_) => None,
325            MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
326            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
327        }
328    }
329
330    pub fn is_empty(&self) -> bool {
331        match self {
332            MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
333            MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
334            MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
335            MessageContent::RedactedThinking(_)
336            | MessageContent::ToolUse(_)
337            | MessageContent::Image(_) => false,
338        }
339    }
340}
341
342impl From<String> for MessageContent {
343    fn from(value: String) -> Self {
344        MessageContent::Text(value)
345    }
346}
347
348impl From<&str> for MessageContent {
349    fn from(value: &str) -> Self {
350        MessageContent::Text(value.to_string())
351    }
352}
353
/// A single message in a language-model request.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    /// Who authored the message (user / assistant / system).
    pub role: Role,
    /// Ordered content items making up the message body.
    pub content: Vec<MessageContent>,
    /// Whether this message should be marked for provider-side prompt
    /// caching. NOTE(review): exact caching semantics depend on the
    /// provider adapter — confirm against callers.
    pub cache: bool,
}
360
361impl LanguageModelRequestMessage {
362    pub fn string_contents(&self) -> String {
363        let mut buffer = String::new();
364        for string in self.content.iter().filter_map(|content| content.to_str()) {
365            buffer.push_str(string);
366        }
367
368        buffer
369    }
370
371    pub fn contents_empty(&self) -> bool {
372        self.content.iter().all(|content| content.is_empty())
373    }
374}
375
/// A tool the model is allowed to call during this request.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    /// Tool name the model will reference when invoking it.
    pub name: String,
    /// Human-readable description shown to the model.
    pub description: String,
    /// JSON Schema describing the tool's expected input.
    pub input_schema: serde_json::Value,
}
382
/// How the model should decide whether to call tools.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
    /// The model decides on its own whether to call a tool.
    Auto,
    /// The model must call some tool.
    Any,
    /// The model must not call any tool.
    None,
}
389
/// A complete request to a language model.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    /// Identifier of the conversation thread, if any.
    pub thread_id: Option<String>,
    /// Identifier of the prompt that produced this request, if any.
    pub prompt_id: Option<String>,
    /// Why this completion is being requested.
    pub intent: Option<CompletionIntent>,
    /// Completion mode (provider-defined).
    pub mode: Option<CompletionMode>,
    /// Conversation history plus the message being completed.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Tools the model may call.
    pub tools: Vec<LanguageModelRequestTool>,
    /// Tool-calling policy; `None` leaves it to the provider default.
    pub tool_choice: Option<LanguageModelToolChoice>,
    /// Stop sequences that end generation early.
    pub stop: Vec<String>,
    /// Sampling temperature; `None` uses the provider default.
    pub temperature: Option<f32>,
    /// Whether "thinking" output is permitted for this request.
    pub thinking_allowed: bool,
}
403
/// A message returned by a language model; both fields are optional because
/// providers may stream partial deltas.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    /// Role of the responder, when reported.
    pub role: Option<Role>,
    /// Text content, when present.
    pub content: Option<String>,
}
409
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises every accepted input shape of the custom
    /// `LanguageModelToolResultContent` deserializer, plus the shapes that
    /// must be rejected.
    #[test]
    fn test_language_model_tool_result_content_deserialization() {
        // Bare JSON string -> Text.
        let json = r#""This is plain text""#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("This is plain text".into())
        );

        // Tagged object form: { "type": "text", "text": ... }.
        let json = r#"{"type": "text", "text": "This is wrapped text"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("This is wrapped text".into())
        );

        // Tag and field names match case-insensitively.
        let json = r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("Case insensitive".into())
        );

        // Single-field wrapped variant: { "Text": ... }.
        let json = r#"{"Text": "Wrapped variant"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("Wrapped variant".into())
        );

        let json = r#"{"text": "Lowercase wrapped"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("Lowercase wrapped".into())
        );

        // Test image deserialization
        let json = r#"{
            "source": "base64encodedimagedata",
            "size": {
                "width": 100,
                "height": 200
            }
        }"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        match result {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "base64encodedimagedata");
                assert_eq!(image.size.width.0, 100);
                assert_eq!(image.size.height.0, 200);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test wrapped Image variant
        let json = r#"{
            "Image": {
                "source": "wrappedimagedata",
                "size": {
                    "width": 50,
                    "height": 75
                }
            }
        }"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        match result {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "wrappedimagedata");
                assert_eq!(image.size.width.0, 50);
                assert_eq!(image.size.height.0, 75);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test wrapped Image variant with case insensitive
        let json = r#"{
            "image": {
                "Source": "caseinsensitive",
                "SIZE": {
                    "width": 30,
                    "height": 40
                }
            }
        }"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        match result {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "caseinsensitive");
                assert_eq!(image.size.width.0, 30);
                assert_eq!(image.size.height.0, 40);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test that wrapped text with wrong type fails
        let json = r#"{"type": "blahblah", "text": "This should fail"}"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());

        // Test that malformed JSON fails
        let json = r#"{"invalid": "structure"}"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());

        // Test edge cases
        let json = r#""""#; // Empty string
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(result, LanguageModelToolResultContent::Text("".into()));

        // Test with extra fields in wrapped text (should be ignored)
        let json = r#"{"type": "text", "text": "Hello", "extra": "field"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(result, LanguageModelToolResultContent::Text("Hello".into()));

        // Test direct image with case-insensitive fields
        let json = r#"{
            "SOURCE": "directimage",
            "Size": {
                "width": 200,
                "height": 300
            }
        }"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        match result {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "directimage");
                assert_eq!(image.size.width.0, 200);
                assert_eq!(image.size.height.0, 300);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test that multiple fields prevent wrapped variant interpretation
        let json = r#"{"Text": "not wrapped", "extra": "field"}"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());

        // Test wrapped text with uppercase TEXT variant
        let json = r#"{"TEXT": "Uppercase variant"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("Uppercase variant".into())
        );

        // Test that numbers and other JSON values fail gracefully
        let json = r#"123"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());

        let json = r#"null"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());

        let json = r#"[1, 2, 3]"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }
}