request.rs

  1use std::io::{Cursor, Write};
  2use std::sync::Arc;
  3
  4use anyhow::Result;
  5use base64::write::EncoderWriter;
  6use cloud_llm_client::{CompletionIntent, CompletionMode};
  7use gpui::{
  8    App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
  9    point, px, size,
 10};
 11use image::codecs::png::PngEncoder;
 12use serde::{Deserialize, Serialize};
 13use util::ResultExt;
 14
 15use crate::role::Role;
 16use crate::{LanguageModelToolUse, LanguageModelToolUseId};
 17
 18#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 19pub struct LanguageModelImage {
 20    /// A base64-encoded PNG image.
 21    pub source: SharedString,
 22    #[serde(default, skip_serializing_if = "Option::is_none")]
 23    pub size: Option<Size<DevicePixels>>,
 24}
 25
 26impl LanguageModelImage {
 27    pub fn len(&self) -> usize {
 28        self.source.len()
 29    }
 30
 31    pub fn is_empty(&self) -> bool {
 32        self.source.is_empty()
 33    }
 34
 35    // Parse Self from a JSON object with case-insensitive field names
 36    pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
 37        let mut source = None;
 38        let mut size_obj = None;
 39
 40        // Find source and size fields (case-insensitive)
 41        for (k, v) in obj.iter() {
 42            match k.to_lowercase().as_str() {
 43                "source" => source = v.as_str(),
 44                "size" => size_obj = v.as_object(),
 45                _ => {}
 46            }
 47        }
 48
 49        let source = source?;
 50        let size_obj = size_obj?;
 51
 52        let mut width = None;
 53        let mut height = None;
 54
 55        // Find width and height in size object (case-insensitive)
 56        for (k, v) in size_obj.iter() {
 57            match k.to_lowercase().as_str() {
 58                "width" => width = v.as_i64().map(|w| w as i32),
 59                "height" => height = v.as_i64().map(|h| h as i32),
 60                _ => {}
 61            }
 62        }
 63
 64        Some(Self {
 65            size: Some(size(DevicePixels(width?), DevicePixels(height?))),
 66            source: SharedString::from(source.to_string()),
 67        })
 68    }
 69}
 70
 71impl std::fmt::Debug for LanguageModelImage {
 72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 73        f.debug_struct("LanguageModelImage")
 74            .field("source", &format!("<{} bytes>", self.source.len()))
 75            .field("size", &self.size)
 76            .finish()
 77    }
 78}
 79
 /// Maximum width and height, in pixels, of images produced by
 /// `LanguageModelImage::from_image`. Images larger than this in either dimension
 /// are scaled down to fit.
 ///
 /// Anthropic wants uploaded images to be no larger than this in both dimensions.
 const ANTHROPIC_SIZE_LIMIT: f32 = 1568.0;
 82
 83impl LanguageModelImage {
 84    pub fn empty() -> Self {
 85        Self {
 86            source: "".into(),
 87            size: None,
 88        }
 89    }
 90
 91    pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
 92        cx.background_spawn(async move {
 93            let image_bytes = Cursor::new(data.bytes());
 94            let dynamic_image = match data.format() {
 95                ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
 96                    .and_then(image::DynamicImage::from_decoder),
 97                ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
 98                    .and_then(image::DynamicImage::from_decoder),
 99                ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
100                    .and_then(image::DynamicImage::from_decoder),
101                ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
102                    .and_then(image::DynamicImage::from_decoder),
103                ImageFormat::Bmp => image::codecs::bmp::BmpDecoder::new(image_bytes)
104                    .and_then(image::DynamicImage::from_decoder),
105                ImageFormat::Tiff => image::codecs::tiff::TiffDecoder::new(image_bytes)
106                    .and_then(image::DynamicImage::from_decoder),
107                _ => return None,
108            }
109            .log_err()?;
110
111            let width = dynamic_image.width();
112            let height = dynamic_image.height();
113            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));
114
115            let base64_image = {
116                if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
117                    || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
118                {
119                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
120                        gpui::Bounds {
121                            origin: point(px(0.0), px(0.0)),
122                            size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
123                        },
124                        image_size,
125                    );
126                    let resized_image = dynamic_image.resize(
127                        new_bounds.size.width.into(),
128                        new_bounds.size.height.into(),
129                        image::imageops::FilterType::Triangle,
130                    );
131
132                    encode_as_base64(data, resized_image)
133                } else {
134                    encode_as_base64(data, dynamic_image)
135                }
136            }
137            .log_err()?;
138
139            // SAFETY: The base64 encoder should not produce non-UTF8.
140            let source = unsafe { String::from_utf8_unchecked(base64_image) };
141
142            Some(LanguageModelImage {
143                size: Some(image_size),
144                source: source.into(),
145            })
146        })
147    }
148
149    pub fn estimate_tokens(&self) -> usize {
150        let Some(size) = self.size.as_ref() else {
151            return 0;
152        };
153        let width = size.width.0.unsigned_abs() as usize;
154        let height = size.height.0.unsigned_abs() as usize;
155
156        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
157        // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
158        // so this method is more of a rough guess.
159        (width * height) / 750
160    }
161
162    pub fn to_base64_url(&self) -> String {
163        format!("data:image/png;base64,{}", self.source)
164    }
165}
166
167fn encode_as_base64(data: Arc<Image>, image: image::DynamicImage) -> Result<Vec<u8>> {
168    let mut base64_image = Vec::new();
169    {
170        let mut base64_encoder = EncoderWriter::new(
171            Cursor::new(&mut base64_image),
172            &base64::engine::general_purpose::STANDARD,
173        );
174        if data.format() == ImageFormat::Png {
175            base64_encoder.write_all(data.bytes())?;
176        } else {
177            let mut png = Vec::new();
178            image.write_with_encoder(PngEncoder::new(&mut png))?;
179
180            base64_encoder.write_all(png.as_slice())?;
181        }
182    }
183    Ok(base64_image)
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
187pub struct LanguageModelToolResult {
188    pub tool_use_id: LanguageModelToolUseId,
189    pub tool_name: Arc<str>,
190    pub is_error: bool,
191    pub content: LanguageModelToolResultContent,
192    pub output: Option<serde_json::Value>,
193}
194
195#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
196pub enum LanguageModelToolResultContent {
197    Text(Arc<str>),
198    Image(LanguageModelImage),
199}
200
201impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
202    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
203    where
204        D: serde::Deserializer<'de>,
205    {
206        use serde::de::Error;
207
208        let value = serde_json::Value::deserialize(deserializer)?;
209
210        // Models can provide these responses in several styles. Try each in order.
211
212        // 1. Try as plain string
213        if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
214            return Ok(Self::Text(Arc::from(text)));
215        }
216
217        // 2. Try as object
218        if let Some(obj) = value.as_object() {
219            // get a JSON field case-insensitively
220            fn get_field<'a>(
221                obj: &'a serde_json::Map<String, serde_json::Value>,
222                field: &str,
223            ) -> Option<&'a serde_json::Value> {
224                obj.iter()
225                    .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
226                    .map(|(_, v)| v)
227            }
228
229            // Accept wrapped text format: { "type": "text", "text": "..." }
230            if let (Some(type_value), Some(text_value)) =
231                (get_field(obj, "type"), get_field(obj, "text"))
232                && let Some(type_str) = type_value.as_str()
233                && type_str.to_lowercase() == "text"
234                && let Some(text) = text_value.as_str()
235            {
236                return Ok(Self::Text(Arc::from(text)));
237            }
238
239            // Check for wrapped Text variant: { "text": "..." }
240            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text")
241                && obj.len() == 1
242            {
243                // Only one field, and it's "text" (case-insensitive)
244                if let Some(text) = value.as_str() {
245                    return Ok(Self::Text(Arc::from(text)));
246                }
247            }
248
249            // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
250            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image")
251                && obj.len() == 1
252            {
253                // Only one field, and it's "image" (case-insensitive)
254                // Try to parse the nested image object
255                if let Some(image_obj) = value.as_object()
256                    && let Some(image) = LanguageModelImage::from_json(image_obj)
257                {
258                    return Ok(Self::Image(image));
259                }
260            }
261
262            // Try as direct Image (object with "source" and "size" fields)
263            if let Some(image) = LanguageModelImage::from_json(obj) {
264                return Ok(Self::Image(image));
265            }
266        }
267
268        // If none of the variants match, return an error with the problematic JSON
269        Err(D::Error::custom(format!(
270            "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
271             an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
272            serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
273        )))
274    }
275}
276
277impl LanguageModelToolResultContent {
278    pub fn to_str(&self) -> Option<&str> {
279        match self {
280            Self::Text(text) => Some(text),
281            Self::Image(_) => None,
282        }
283    }
284
285    pub fn is_empty(&self) -> bool {
286        match self {
287            Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
288            Self::Image(_) => false,
289        }
290    }
291}
292
293impl From<&str> for LanguageModelToolResultContent {
294    fn from(value: &str) -> Self {
295        Self::Text(Arc::from(value))
296    }
297}
298
299impl From<String> for LanguageModelToolResultContent {
300    fn from(value: String) -> Self {
301        Self::Text(Arc::from(value))
302    }
303}
304
305impl From<LanguageModelImage> for LanguageModelToolResultContent {
306    fn from(image: LanguageModelImage) -> Self {
307        Self::Image(image)
308    }
309}
310
311#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
312pub enum MessageContent {
313    Text(String),
314    Thinking {
315        text: String,
316        signature: Option<String>,
317    },
318    RedactedThinking(String),
319    Image(LanguageModelImage),
320    ToolUse(LanguageModelToolUse),
321    ToolResult(LanguageModelToolResult),
322}
323
324impl MessageContent {
325    pub fn to_str(&self) -> Option<&str> {
326        match self {
327            MessageContent::Text(text) => Some(text.as_str()),
328            MessageContent::Thinking { text, .. } => Some(text.as_str()),
329            MessageContent::RedactedThinking(_) => None,
330            MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
331            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
332        }
333    }
334
335    pub fn is_empty(&self) -> bool {
336        match self {
337            MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
338            MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
339            MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
340            MessageContent::RedactedThinking(_)
341            | MessageContent::ToolUse(_)
342            | MessageContent::Image(_) => false,
343        }
344    }
345}
346
347impl From<String> for MessageContent {
348    fn from(value: String) -> Self {
349        MessageContent::Text(value)
350    }
351}
352
353impl From<&str> for MessageContent {
354    fn from(value: &str) -> Self {
355        MessageContent::Text(value.to_string())
356    }
357}
358
359#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
360pub struct LanguageModelRequestMessage {
361    pub role: Role,
362    pub content: Vec<MessageContent>,
363    pub cache: bool,
364    #[serde(default, skip_serializing_if = "Option::is_none")]
365    pub reasoning_details: Option<serde_json::Value>,
366}
367
368impl LanguageModelRequestMessage {
369    pub fn string_contents(&self) -> String {
370        let mut buffer = String::new();
371        for string in self.content.iter().filter_map(|content| content.to_str()) {
372            buffer.push_str(string);
373        }
374
375        buffer
376    }
377
378    pub fn contents_empty(&self) -> bool {
379        self.content.iter().all(|content| content.is_empty())
380    }
381}
382
383#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
384pub struct LanguageModelRequestTool {
385    pub name: String,
386    pub description: String,
387    pub input_schema: serde_json::Value,
388}
389
390#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
391pub enum LanguageModelToolChoice {
392    Auto,
393    Any,
394    None,
395}
396
397#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
398pub struct LanguageModelRequest {
399    pub thread_id: Option<String>,
400    pub prompt_id: Option<String>,
401    pub intent: Option<CompletionIntent>,
402    pub mode: Option<CompletionMode>,
403    pub messages: Vec<LanguageModelRequestMessage>,
404    pub tools: Vec<LanguageModelRequestTool>,
405    pub tool_choice: Option<LanguageModelToolChoice>,
406    pub stop: Vec<String>,
407    pub temperature: Option<f32>,
408    pub thinking_allowed: bool,
409}
410
411#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
412pub struct LanguageModelResponseMessage {
413    pub role: Option<Role>,
414    pub content: Option<String>,
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json`, panicking if it is rejected.
    fn parse(json: &str) -> LanguageModelToolResultContent {
        serde_json::from_str(json).unwrap()
    }

    /// Asserts that `json` deserializes to the `Text` variant holding `expected`.
    fn assert_text(json: &str, expected: &str) {
        assert_eq!(
            parse(json),
            LanguageModelToolResultContent::Text(expected.into())
        );
    }

    /// Asserts that `json` deserializes to an image with the given source and size.
    fn assert_image(json: &str, expected_source: &str, expected_width: i32, expected_height: i32) {
        match parse(json) {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), expected_source);
                let size = image.size.expect("size");
                assert_eq!(size.width.0, expected_width);
                assert_eq!(size.height.0, expected_height);
            }
            _ => panic!("Expected Image variant"),
        }
    }

    /// Asserts that `json` is rejected by the deserializer.
    fn assert_rejected(json: &str) {
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_language_model_tool_result_content_deserialization() {
        // Plain and wrapped text forms.
        assert_text(r#""This is plain text""#, "This is plain text");
        assert_text(
            r#"{"type": "text", "text": "This is wrapped text"}"#,
            "This is wrapped text",
        );
        assert_text(
            r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#,
            "Case insensitive",
        );
        assert_text(r#"{"Text": "Wrapped variant"}"#, "Wrapped variant");
        assert_text(r#"{"text": "Lowercase wrapped"}"#, "Lowercase wrapped");

        // Direct image object.
        assert_image(
            r#"{
                "source": "base64encodedimagedata",
                "size": {
                    "width": 100,
                    "height": 200
                }
            }"#,
            "base64encodedimagedata",
            100,
            200,
        );

        // Wrapped Image variant.
        assert_image(
            r#"{
                "Image": {
                    "source": "wrappedimagedata",
                    "size": {
                        "width": 50,
                        "height": 75
                    }
                }
            }"#,
            "wrappedimagedata",
            50,
            75,
        );

        // Wrapped Image variant with case-insensitive field names.
        assert_image(
            r#"{
                "image": {
                    "Source": "caseinsensitive",
                    "SIZE": {
                        "width": 30,
                        "height": 40
                    }
                }
            }"#,
            "caseinsensitive",
            30,
            40,
        );

        // Wrapped text with the wrong type tag is rejected.
        assert_rejected(r#"{"type": "blahblah", "text": "This should fail"}"#);

        // Unrecognized structure is rejected.
        assert_rejected(r#"{"invalid": "structure"}"#);

        // Empty string.
        assert_text(r#""""#, "");

        // Extra fields alongside wrapped text are ignored.
        assert_text(r#"{"type": "text", "text": "Hello", "extra": "field"}"#, "Hello");

        // Direct image with case-insensitive field names.
        assert_image(
            r#"{
                "SOURCE": "directimage",
                "Size": {
                    "width": 200,
                    "height": 300
                }
            }"#,
            "directimage",
            200,
            300,
        );

        // Additional fields prevent the wrapped-variant interpretation.
        assert_rejected(r#"{"Text": "not wrapped", "extra": "field"}"#);

        // Uppercase wrapped TEXT variant.
        assert_text(r#"{"TEXT": "Uppercase variant"}"#, "Uppercase variant");

        // Numbers, null and arrays are rejected.
        assert_rejected(r#"123"#);
        assert_rejected(r#"null"#);
        assert_rejected(r#"[1, 2, 3]"#);
    }
}