use std::io::{Cursor, Write};
use std::sync::Arc;

use crate::role::Role;
use crate::{LanguageModelToolUse, LanguageModelToolUseId};
use anyhow::Result;
use base64::write::EncoderWriter;
use gpui::{
    App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
    point, px, size,
};
use image::codecs::png::PngEncoder;
use serde::{Deserialize, Serialize};
use util::ResultExt;
use zed_llm_client::CompletionMode;

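/// An image that can be attached to a language model request or tool result,
/// stored as a base64-encoded PNG along with its size in device pixels.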
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    pub size: Size<DevicePixels>,
}

impl LanguageModelImage {
    pub fn len(&self) -> usize {
        self.source.len()
    }

    pub fn is_empty(&self) -> bool {
        self.source.is_empty()
    }

    // Parse Self from a JSON object with case-insensitive field names
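    // e.g. {"Source": "...", "SIZE": {"width": 100, "height": 200}} is accepted.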
    pub fn from_json(obj: &serde_json::Map<String, serde_json::Value>) -> Option<Self> {
        let mut source = None;
        let mut size_obj = None;

        // Find source and size fields (case-insensitive)
        for (k, v) in obj.iter() {
            match k.to_lowercase().as_str() {
                "source" => source = v.as_str(),
                "size" => size_obj = v.as_object(),
                _ => {}
            }
        }

        let source = source?;
        let size_obj = size_obj?;

        let mut width = None;
        let mut height = None;

        // Find width and height in size object (case-insensitive)
        for (k, v) in size_obj.iter() {
            match k.to_lowercase().as_str() {
                "width" => width = v.as_i64().map(|w| w as i32),
                "height" => height = v.as_i64().map(|h| h as i32),
                _ => {}
            }
        }

        Some(Self {
            size: size(DevicePixels(width?), DevicePixels(height?)),
            source: SharedString::from(source.to_string()),
        })
    }
}

impl std::fmt::Debug for LanguageModelImage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanguageModelImage")
            .field("source", &format!("<{} bytes>", self.source.len()))
            .field("size", &self.size)
            .finish()
    }
}

/// Anthropic wants uploaded images to be smaller than this in both dimensions.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;

impl LanguageModelImage {
    pub fn empty() -> Self {
        Self {
            source: "".into(),
            size: size(DevicePixels(0), DevicePixels(0)),
        }
    }

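    /// Converts a gpui `Image` into a `LanguageModelImage` on a background task,
    /// returning `None` if the image can't be decoded or its format is unsupported.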
    pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
        cx.background_spawn(async move {
            let image_bytes = Cursor::new(data.bytes());
            let dynamic_image = match data.format() {
                ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                _ => return None,
            }
            .log_err()?;

            let width = dynamic_image.width();
            let height = dynamic_image.height();
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

            let base64_image = {
                if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                    || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
                {
                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
                        gpui::Bounds {
                            origin: point(px(0.0), px(0.0)),
                            size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                        },
                        image_size,
                    );
                    let resized_image = dynamic_image.resize(
                        new_bounds.size.width.0 as u32,
                        new_bounds.size.height.0 as u32,
                        image::imageops::FilterType::Triangle,
                    );

                    encode_as_base64(data, resized_image)
                } else {
                    encode_as_base64(data, dynamic_image)
                }
            }
            .log_err()?;

            // SAFETY: Base64 output is restricted to the ASCII alphabet, so it is always valid UTF-8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Note that there are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
        // so this method is more of a rough guess.
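        // For example, a 1000x1000 image works out to 1_000_000 / 750 ≈ 1333 tokens.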
        (width * height) / 750
    }

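    /// Renders the image as a `data:image/png;base64,...` URL.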
    pub fn to_base64_url(&self) -> String {
        format!("data:image/png;base64,{}", self.source)
    }
}

fn encode_as_base64(data: Arc<Image>, image: image::DynamicImage) -> Result<Vec<u8>> {
    let mut base64_image = Vec::new();
    {
        let mut base64_encoder = EncoderWriter::new(
            Cursor::new(&mut base64_image),
            &base64::engine::general_purpose::STANDARD,
        );
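        // PNG sources are written through unchanged; other formats are re-encoded as PNG first.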
        if data.format() == ImageFormat::Png {
            base64_encoder.write_all(data.bytes())?;
        } else {
            let mut png = Vec::new();
            image.write_with_encoder(PngEncoder::new(&mut png))?;

            base64_encoder.write_all(png.as_slice())?;
        }
    }
    Ok(base64_image)
}

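/// The result of a single tool call, identified by the `tool_use_id` of the request
/// that produced it.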
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub tool_name: Arc<str>,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    pub output: Option<serde_json::Value>,
}

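/// The content of a tool result: either plain text or an image.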
#[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)]
pub enum LanguageModelToolResultContent {
    Text(Arc<str>),
    Image(LanguageModelImage),
}

impl<'de> Deserialize<'de> for LanguageModelToolResultContent {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::Error;

        let value = serde_json::Value::deserialize(deserializer)?;

        // Models can provide these responses in several styles. Try each in order.
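        //
        // Accepted shapes:
        //   "plain text"
        //   { "type": "text", "text": "..." }
        //   { "Text": "..." } or { "Image": { ... } }   (serde's externally tagged form)
        //   { "source": "...", "size": { "width": ..., "height": ... } }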

        // 1. Try as plain string
        if let Ok(text) = serde_json::from_value::<String>(value.clone()) {
            return Ok(Self::Text(Arc::from(text)));
        }

        // 2. Try as object
        if let Some(obj) = value.as_object() {
            // get a JSON field case-insensitively
            fn get_field<'a>(
                obj: &'a serde_json::Map<String, serde_json::Value>,
                field: &str,
            ) -> Option<&'a serde_json::Value> {
                obj.iter()
                    .find(|(k, _)| k.to_lowercase() == field.to_lowercase())
                    .map(|(_, v)| v)
            }

            // Accept wrapped text format: { "type": "text", "text": "..." }
            if let (Some(type_value), Some(text_value)) =
                (get_field(&obj, "type"), get_field(&obj, "text"))
            {
                if let Some(type_str) = type_value.as_str() {
                    if type_str.to_lowercase() == "text" {
                        if let Some(text) = text_value.as_str() {
                            return Ok(Self::Text(Arc::from(text)));
                        }
                    }
                }
            }

            // Check for wrapped Text variant: { "text": "..." }
            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "text") {
                if obj.len() == 1 {
                    // Only one field, and it's "text" (case-insensitive)
                    if let Some(text) = value.as_str() {
                        return Ok(Self::Text(Arc::from(text)));
                    }
                }
            }

            // Check for wrapped Image variant: { "image": { "source": "...", "size": ... } }
            if let Some((_key, value)) = obj.iter().find(|(k, _)| k.to_lowercase() == "image") {
                if obj.len() == 1 {
                    // Only one field, and it's "image" (case-insensitive)
                    // Try to parse the nested image object
                    if let Some(image_obj) = value.as_object() {
                        if let Some(image) = LanguageModelImage::from_json(image_obj) {
                            return Ok(Self::Image(image));
                        }
                    }
                }
            }

            // Try as direct Image (object with "source" and "size" fields)
            if let Some(image) = LanguageModelImage::from_json(&obj) {
                return Ok(Self::Image(image));
            }
        }

        // If none of the variants match, return an error with the problematic JSON
        Err(D::Error::custom(format!(
            "data did not match any variant of LanguageModelToolResultContent. Expected either a string, \
             an object with 'type': 'text', a wrapped variant like {{\"Text\": \"...\"}}, or an image object. Got: {}",
            serde_json::to_string_pretty(&value).unwrap_or_else(|_| value.to_string())
        )))
    }
}

impl LanguageModelToolResultContent {
    pub fn to_str(&self) -> Option<&str> {
        match self {
            Self::Text(text) => Some(&text),
            Self::Image(_) => None,
        }
    }

    pub fn is_empty(&self) -> bool {
        match self {
            Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
            Self::Image(_) => false,
        }
    }
}

impl From<&str> for LanguageModelToolResultContent {
    fn from(value: &str) -> Self {
        Self::Text(Arc::from(value))
    }
}

impl From<String> for LanguageModelToolResultContent {
    fn from(value: String) -> Self {
        Self::Text(Arc::from(value))
    }
}

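/// A single piece of content within a `LanguageModelRequestMessage`.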
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
    Image(LanguageModelImage),
    ToolUse(LanguageModelToolUse),
    ToolResult(LanguageModelToolResult),
}

impl MessageContent {
    pub fn to_str(&self) -> Option<&str> {
        match self {
            MessageContent::Text(text) => Some(text.as_str()),
            MessageContent::Thinking { text, .. } => Some(text.as_str()),
            MessageContent::RedactedThinking(_) => None,
            MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
        }
    }

    pub fn is_empty(&self) -> bool {
        match self {
            MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
            MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
            MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
            MessageContent::RedactedThinking(_)
            | MessageContent::ToolUse(_)
            | MessageContent::Image(_) => false,
        }
    }
}

impl From<String> for MessageContent {
    fn from(value: String) -> Self {
        MessageContent::Text(value)
    }
}

impl From<&str> for MessageContent {
    fn from(value: &str) -> Self {
        MessageContent::Text(value.to_string())
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
    pub cache: bool,
}

impl LanguageModelRequestMessage {
    pub fn string_contents(&self) -> String {
        let mut buffer = String::new();
        for string in self.content.iter().filter_map(|content| content.to_str()) {
            buffer.push_str(string);
        }

        buffer
    }

    pub fn contents_empty(&self) -> bool {
        self.content.iter().all(|content| content.is_empty())
    }
}

#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

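/// Controls whether the model may call tools: `Auto` lets the model decide, `Any`
/// requires it to call some tool, and `None` forbids tool calls (mirroring the
/// tool-choice options exposed by common provider APIs).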
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
    Auto,
    Any,
    None,
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub thread_id: Option<String>,
    pub prompt_id: Option<String>,
    pub mode: Option<CompletionMode>,
    pub messages: Vec<LanguageModelRequestMessage>,
    pub tools: Vec<LanguageModelRequestTool>,
    pub tool_choice: Option<LanguageModelToolChoice>,
    pub stop: Vec<String>,
    pub temperature: Option<f32>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_language_model_tool_result_content_deserialization() {
        let json = r#""This is plain text""#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("This is plain text".into())
        );

        let json = r#"{"type": "text", "text": "This is wrapped text"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("This is wrapped text".into())
        );

        let json = r#"{"Type": "TEXT", "TEXT": "Case insensitive"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("Case insensitive".into())
        );

        let json = r#"{"Text": "Wrapped variant"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("Wrapped variant".into())
        );

        let json = r#"{"text": "Lowercase wrapped"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("Lowercase wrapped".into())
        );

        // Test image deserialization
        let json = r#"{
            "source": "base64encodedimagedata",
            "size": {
                "width": 100,
                "height": 200
            }
        }"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        match result {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "base64encodedimagedata");
                assert_eq!(image.size.width.0, 100);
                assert_eq!(image.size.height.0, 200);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test wrapped Image variant
        let json = r#"{
            "Image": {
                "source": "wrappedimagedata",
                "size": {
                    "width": 50,
                    "height": 75
                }
            }
        }"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        match result {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "wrappedimagedata");
                assert_eq!(image.size.width.0, 50);
                assert_eq!(image.size.height.0, 75);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test wrapped Image variant with case-insensitive keys
        let json = r#"{
            "image": {
                "Source": "caseinsensitive",
                "SIZE": {
                    "width": 30,
                    "height": 40
                }
            }
        }"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        match result {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "caseinsensitive");
                assert_eq!(image.size.width.0, 30);
                assert_eq!(image.size.height.0, 40);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test that wrapped text with the wrong type fails
        let json = r#"{"type": "blahblah", "text": "This should fail"}"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());

        // Test that an unrecognized object shape fails
        let json = r#"{"invalid": "structure"}"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());

        // Test edge cases
        let json = r#""""#; // Empty string
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(result, LanguageModelToolResultContent::Text("".into()));

        // Test with extra fields in wrapped text (should be ignored)
        let json = r#"{"type": "text", "text": "Hello", "extra": "field"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(result, LanguageModelToolResultContent::Text("Hello".into()));

        // Test direct image with case-insensitive fields
        let json = r#"{
            "SOURCE": "directimage",
            "Size": {
                "width": 200,
                "height": 300
            }
        }"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        match result {
            LanguageModelToolResultContent::Image(image) => {
                assert_eq!(image.source.as_ref(), "directimage");
                assert_eq!(image.size.width.0, 200);
                assert_eq!(image.size.height.0, 300);
            }
            _ => panic!("Expected Image variant"),
        }

        // Test that multiple fields prevent wrapped variant interpretation
        let json = r#"{"Text": "not wrapped", "extra": "field"}"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());

        // Test wrapped text with uppercase TEXT variant
        let json = r#"{"TEXT": "Uppercase variant"}"#;
        let result: LanguageModelToolResultContent = serde_json::from_str(json).unwrap();
        assert_eq!(
            result,
            LanguageModelToolResultContent::Text("Uppercase variant".into())
        );

        // Test that numbers and other JSON values fail gracefully
        let json = r#"123"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());

        let json = r#"null"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());

        let json = r#"[1, 2, 3]"#;
        let result: Result<LanguageModelToolResultContent, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }
}