1use std::io::{Cursor, Write};
2use std::sync::Arc;
3
4use crate::role::Role;
5use crate::{LanguageModelToolUse, LanguageModelToolUseId};
6use anyhow::Result;
7use base64::write::EncoderWriter;
8use gpui::{
9 App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
10 point, px, size,
11};
12use image::codecs::png::PngEncoder;
13use serde::{Deserialize, Serialize};
14use util::ResultExt;
15use zed_llm_client::CompletionMode;
16
/// An image attachment for a language-model request: base64-encoded PNG
/// data together with the image's pixel dimensions.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    /// Pixel dimensions of the decoded image; set by `empty`/`from_image`
    /// in this file and used by `estimate_tokens`. Kept private.
    size: Size<DevicePixels>,
}
23
24impl LanguageModelImage {
25 pub fn len(&self) -> usize {
26 self.source.len()
27 }
28
29 pub fn is_empty(&self) -> bool {
30 self.source.is_empty()
31 }
32}
33
34impl std::fmt::Debug for LanguageModelImage {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("LanguageModelImage")
37 .field("source", &format!("<{} bytes>", self.source.len()))
38 .field("size", &self.size)
39 .finish()
40 }
41}
42
43/// Anthropic wants uploaded images to be smaller than this in both dimensions.
44const ANTHROPIC_SIZE_LIMT: f32 = 1568.;
45
46impl LanguageModelImage {
47 pub fn empty() -> Self {
48 Self {
49 source: "".into(),
50 size: size(DevicePixels(0), DevicePixels(0)),
51 }
52 }
53
54 pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
55 cx.background_spawn(async move {
56 let image_bytes = Cursor::new(data.bytes());
57 let dynamic_image = match data.format() {
58 ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
59 .and_then(image::DynamicImage::from_decoder),
60 ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
61 .and_then(image::DynamicImage::from_decoder),
62 ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
63 .and_then(image::DynamicImage::from_decoder),
64 ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
65 .and_then(image::DynamicImage::from_decoder),
66 _ => return None,
67 }
68 .log_err()?;
69
70 let width = dynamic_image.width();
71 let height = dynamic_image.height();
72 let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));
73
74 let base64_image = {
75 if image_size.width.0 > ANTHROPIC_SIZE_LIMT as i32
76 || image_size.height.0 > ANTHROPIC_SIZE_LIMT as i32
77 {
78 let new_bounds = ObjectFit::ScaleDown.get_bounds(
79 gpui::Bounds {
80 origin: point(px(0.0), px(0.0)),
81 size: size(px(ANTHROPIC_SIZE_LIMT), px(ANTHROPIC_SIZE_LIMT)),
82 },
83 image_size,
84 );
85 let resized_image = dynamic_image.resize(
86 new_bounds.size.width.0 as u32,
87 new_bounds.size.height.0 as u32,
88 image::imageops::FilterType::Triangle,
89 );
90
91 encode_as_base64(data, resized_image)
92 } else {
93 encode_as_base64(data, dynamic_image)
94 }
95 }
96 .log_err()?;
97
98 // SAFETY: The base64 encoder should not produce non-UTF8.
99 let source = unsafe { String::from_utf8_unchecked(base64_image) };
100
101 Some(LanguageModelImage {
102 size: image_size,
103 source: source.into(),
104 })
105 })
106 }
107
108 pub fn estimate_tokens(&self) -> usize {
109 let width = self.size.width.0.unsigned_abs() as usize;
110 let height = self.size.height.0.unsigned_abs() as usize;
111
112 // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
113 // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
114 // so this method is more of a rough guess.
115 (width * height) / 750
116 }
117
118 pub fn to_base64_url(&self) -> String {
119 format!("data:image/png;base64,{}", self.source)
120 }
121}
122
123fn encode_as_base64(data: Arc<Image>, image: image::DynamicImage) -> Result<Vec<u8>> {
124 let mut base64_image = Vec::new();
125 {
126 let mut base64_encoder = EncoderWriter::new(
127 Cursor::new(&mut base64_image),
128 &base64::engine::general_purpose::STANDARD,
129 );
130 if data.format() == ImageFormat::Png {
131 base64_encoder.write_all(data.bytes())?;
132 } else {
133 let mut png = Vec::new();
134 image.write_with_encoder(PngEncoder::new(&mut png))?;
135
136 base64_encoder.write_all(png.as_slice())?;
137 }
138 }
139 Ok(base64_image)
140}
141
/// The result of running a tool, reported back to the model.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    /// Identifier of the tool use this result answers.
    pub tool_use_id: LanguageModelToolUseId,
    /// Name of the tool that was invoked.
    pub tool_name: Arc<str>,
    /// Whether the tool invocation resulted in an error.
    pub is_error: bool,
    /// The payload presented to the model (text and/or image).
    pub content: LanguageModelToolResultContent,
    /// Optional structured output — presumably machine-readable metadata
    /// alongside `content`; confirm semantics against callers.
    pub output: Option<serde_json::Value>,
}
150
/// Content carried by a tool result: bare text, an image, or text wrapped
/// in an explicit `{ "type": ..., "text": ... }` object.
///
/// NOTE: `untagged` means serde tries variants in declaration order when
/// deserializing, so reordering variants can change what a payload parses as.
#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq, Hash)]
#[serde(untagged)]
pub enum LanguageModelToolResultContent {
    // A bare JSON string deserializes to this variant.
    Text(Arc<str>),
    Image(LanguageModelImage),
    // An object with `type`/`text` fields.
    WrappedText(WrappedTextContent),
}
158
/// Text accompanied by an explicit discriminator; serializes as
/// `{ "type": <content_type>, "text": <text> }`.
#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq, Hash)]
pub struct WrappedTextContent {
    /// The discriminator value, serialized under the key `"type"`.
    #[serde(rename = "type")]
    pub content_type: String,
    /// The wrapped text itself.
    pub text: Arc<str>,
}
165
166impl LanguageModelToolResultContent {
167 pub fn to_str(&self) -> Option<&str> {
168 match self {
169 Self::Text(text) | Self::WrappedText(WrappedTextContent { text, .. }) => Some(&text),
170 Self::Image(_) => None,
171 }
172 }
173
174 pub fn is_empty(&self) -> bool {
175 match self {
176 Self::Text(text) | Self::WrappedText(WrappedTextContent { text, .. }) => {
177 text.chars().all(|c| c.is_whitespace())
178 }
179 Self::Image(_) => false,
180 }
181 }
182}
183
184impl From<&str> for LanguageModelToolResultContent {
185 fn from(value: &str) -> Self {
186 Self::Text(Arc::from(value))
187 }
188}
189
190impl From<String> for LanguageModelToolResultContent {
191 fn from(value: String) -> Self {
192 Self::Text(Arc::from(value))
193 }
194}
195
/// One piece of a request message's content.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    /// Plain text.
    Text(String),
    /// Model "thinking" output, with an optional signature —
    /// presumably provider-issued; confirm against providers.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque redacted-thinking bytes.
    RedactedThinking(Vec<u8>),
    /// An attached image.
    Image(LanguageModelImage),
    /// A tool invocation requested by the model.
    ToolUse(LanguageModelToolUse),
    /// The result of a tool invocation.
    ToolResult(LanguageModelToolResult),
}
208
209impl MessageContent {
210 pub fn to_str(&self) -> Option<&str> {
211 match self {
212 MessageContent::Text(text) => Some(text.as_str()),
213 MessageContent::Thinking { text, .. } => Some(text.as_str()),
214 MessageContent::RedactedThinking(_) => None,
215 MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
216 MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
217 }
218 }
219
220 pub fn is_empty(&self) -> bool {
221 match self {
222 MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
223 MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
224 MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
225 MessageContent::RedactedThinking(_)
226 | MessageContent::ToolUse(_)
227 | MessageContent::Image(_) => false,
228 }
229 }
230}
231
232impl From<String> for MessageContent {
233 fn from(value: String) -> Self {
234 MessageContent::Text(value)
235 }
236}
237
238impl From<&str> for MessageContent {
239 fn from(value: &str) -> Self {
240 MessageContent::Text(value.to_string())
241 }
242}
243
/// A single message in a completion request.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    /// The author of the message.
    pub role: Role,
    /// Ordered content pieces making up the message body.
    pub content: Vec<MessageContent>,
    /// Whether to mark this message for caching — presumably a provider
    /// prompt-cache breakpoint; confirm against callers.
    pub cache: bool,
}
250
251impl LanguageModelRequestMessage {
252 pub fn string_contents(&self) -> String {
253 let mut buffer = String::new();
254 for string in self.content.iter().filter_map(|content| content.to_str()) {
255 buffer.push_str(string);
256 }
257
258 buffer
259 }
260
261 pub fn contents_empty(&self) -> bool {
262 self.content.iter().all(|content| content.is_empty())
263 }
264}
265
/// Declaration of a tool the model is allowed to call.
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    /// Tool name as exposed to the model.
    pub name: String,
    /// Human/model-readable description of what the tool does.
    pub description: String,
    /// JSON Schema describing the tool's input arguments.
    pub input_schema: serde_json::Value,
}
272
/// Constraint on how the model may choose tools (mirrors the providers'
/// `tool_choice` request field).
#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
    /// The model decides whether to call a tool.
    Auto,
    /// The model must call some tool.
    Any,
    /// The model must not call any tool.
    None,
}
279
/// A complete completion request to a language model.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    /// Identifier of the thread this request belongs to, if any.
    pub thread_id: Option<String>,
    /// Identifier of the prompt that produced this request, if any.
    pub prompt_id: Option<String>,
    /// Completion mode; see `zed_llm_client::CompletionMode`.
    pub mode: Option<CompletionMode>,
    /// The messages to send, in conversation order.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Tools the model may call.
    pub tools: Vec<LanguageModelRequestTool>,
    /// Constraint on tool selection, if any.
    pub tool_choice: Option<LanguageModelToolChoice>,
    /// Stop sequences that terminate generation.
    pub stop: Vec<String>,
    /// Sampling temperature override, if any.
    pub temperature: Option<f32>,
}
291
/// A message returned by the model; both fields optional — presumably to
/// accommodate partial/streamed responses; confirm against providers.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    /// Role of the author, when reported.
    pub role: Option<Role>,
    /// Text content, when present.
    pub content: Option<String>,
}