use std::io::{Cursor, Write};
use std::sync::Arc;

use crate::role::Role;
use crate::{LanguageModelToolUse, LanguageModelToolUseId};
use anyhow::Result;
use base64::write::EncoderWriter;
use gpui::{
    App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
    point, px, size,
};
use image::codecs::png::PngEncoder;
use serde::{Deserialize, Serialize};
use util::ResultExt;
use zed_llm_client::CompletionMode;

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    size: Size<DevicePixels>,
}

impl std::fmt::Debug for LanguageModelImage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanguageModelImage")
            .field("source", &format!("<{} bytes>", self.source.len()))
            .field("size", &self.size)
            .finish()
    }
}

/// Anthropic wants uploaded images to be smaller than this in both dimensions.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;

impl LanguageModelImage {
    pub fn empty() -> Self {
        Self {
            source: "".into(),
            size: size(DevicePixels(0), DevicePixels(0)),
        }
    }

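    /// Converts a raw `gpui::Image` into a `LanguageModelImage`: decodes it on a
    /// background task, scales it down to Anthropic's size limit if needed, and
    /// base64-encodes the result as PNG. Returns `None` if the format is
    /// unsupported or decoding fails.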
    pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
        cx.background_spawn(async move {
            let image_bytes = Cursor::new(data.bytes());
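            // Decode according to the image's declared format; any other format
            // is unsupported and yields `None`.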
            let dynamic_image = match data.format() {
                ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                _ => return None,
            }
            .log_err()?;

            let width = dynamic_image.width();
            let height = dynamic_image.height();
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

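            // Images larger than Anthropic's limit on either side are scaled down
            // (preserving aspect ratio via `ObjectFit::ScaleDown`) before encoding.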
            let base64_image = {
                if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                    || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
                {
                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
                        gpui::Bounds {
                            origin: point(px(0.0), px(0.0)),
                            size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                        },
                        image_size,
                    );
                    let resized_image = dynamic_image.resize(
                        new_bounds.size.width.0 as u32,
                        new_bounds.size.height.0 as u32,
                        image::imageops::FilterType::Triangle,
                    );

                    encode_as_base64(data, resized_image)
                } else {
                    encode_as_base64(data, dynamic_image)
                }
            }
            .log_err()?;

            // SAFETY: base64 output is pure ASCII, which is always valid UTF-8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Note that there are a lot of conditions on Anthropic's API, and OpenAI doesn't
        // use this formula, so this method is more of a rough guess.
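        // For example, a 1000x1000 px image comes out to 1,000,000 / 750 ≈ 1,333 tokens.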
        (width * height) / 750
    }

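    /// Renders the image as a `data:` URL. The MIME type is always `image/png`,
    /// matching the PNG encoding performed in `from_image`.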
    pub fn to_base64_url(&self) -> String {
        format!("data:image/png;base64,{}", self.source)
    }
}

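/// Base64-encodes the image, passing PNG data through unchanged and re-encoding
/// any other format as PNG first.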
fn encode_as_base64(data: Arc<Image>, image: image::DynamicImage) -> Result<Vec<u8>> {
    let mut base64_image = Vec::new();
    {
        let mut base64_encoder = EncoderWriter::new(
            Cursor::new(&mut base64_image),
            &base64::engine::general_purpose::STANDARD,
        );
        if data.format() == ImageFormat::Png {
            base64_encoder.write_all(data.bytes())?;
        } else {
            let mut png = Vec::new();
            image.write_with_encoder(PngEncoder::new(&mut png))?;

            base64_encoder.write_all(png.as_slice())?;
        }
    }
    Ok(base64_image)
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub tool_name: Arc<str>,
    pub is_error: bool,
    pub content: Arc<str>,
    pub output: Option<serde_json::Value>,
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
    Image(LanguageModelImage),
    ToolUse(LanguageModelToolUse),
    ToolResult(LanguageModelToolResult),
}

impl From<String> for MessageContent {
    fn from(value: String) -> Self {
        MessageContent::Text(value)
    }
}

impl From<&str> for MessageContent {
    fn from(value: &str) -> Self {
        MessageContent::Text(value.to_string())
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
    pub cache: bool,
}

impl LanguageModelRequestMessage {
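    /// Concatenates all textual content (plain text, thinking, and tool results)
    /// into a single string, skipping images, tool uses, and redacted thinking.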
    pub fn string_contents(&self) -> String {
        let mut buffer = String::new();
        for string in self.content.iter().filter_map(|content| match content {
            MessageContent::Text(text) => Some(text.as_str()),
            MessageContent::Thinking { text, .. } => Some(text.as_str()),
            MessageContent::RedactedThinking(_) => None,
            MessageContent::ToolResult(tool_result) => Some(tool_result.content.as_ref()),
            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
        }) {
            buffer.push_str(string);
        }

        buffer
    }

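    /// Returns true if the message has no meaningful content: all text-like
    /// content is whitespace-only and there are no images, tool uses, or
    /// redacted thinking blocks.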
    pub fn contents_empty(&self) -> bool {
        self.content.iter().all(|content| match content {
            MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
            MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
            MessageContent::ToolResult(tool_result) => {
                tool_result.content.chars().all(|c| c.is_whitespace())
            }
            MessageContent::RedactedThinking(_)
            | MessageContent::ToolUse(_)
            | MessageContent::Image(_) => false,
        })
    }
}

#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

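/// How the model may use tools: `Auto` lets the model decide whether to call a
/// tool, `Any` requires it to call some tool, and `None` disables tool calls.
/// These mirror the tool-choice modes exposed by providers such as Anthropic.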
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
    Auto,
    Any,
    None,
}

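/// A complete request to a language model.
///
/// A minimal construction sketch (field values are illustrative, not required
/// defaults):
///
/// ```ignore
/// let request = LanguageModelRequest {
///     messages: vec![LanguageModelRequestMessage {
///         role: Role::User,
///         content: vec!["Summarize this file.".into()],
///         cache: false,
///     }],
///     temperature: Some(0.7),
///     ..Default::default()
/// };
/// ```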
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub thread_id: Option<String>,
    pub prompt_id: Option<String>,
    pub mode: Option<CompletionMode>,
    pub messages: Vec<LanguageModelRequestMessage>,
    pub tools: Vec<LanguageModelRequestTool>,
    pub tool_choice: Option<LanguageModelToolChoice>,
    pub stop: Vec<String>,
    pub temperature: Option<f32>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}