1use std::io::{Cursor, Write};
2use std::sync::Arc;
3
4use crate::role::Role;
5use crate::{LanguageModelToolUse, LanguageModelToolUseId};
6use anyhow::Result;
7use base64::write::EncoderWriter;
8use gpui::{
9 App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
10 point, px, size,
11};
12use image::codecs::png::PngEncoder;
13use serde::{Deserialize, Serialize};
14use util::ResultExt;
15use zed_llm_client::CompletionMode;
16
17#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
18pub struct LanguageModelImage {
19 /// A base64-encoded PNG image.
20 pub source: SharedString,
21 size: Size<DevicePixels>,
22}
23
24impl LanguageModelImage {
25 pub fn len(&self) -> usize {
26 self.source.len()
27 }
28
29 pub fn is_empty(&self) -> bool {
30 self.source.is_empty()
31 }
32}
33
34impl std::fmt::Debug for LanguageModelImage {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("LanguageModelImage")
37 .field("source", &format!("<{} bytes>", self.source.len()))
38 .field("size", &self.size)
39 .finish()
40 }
41}
42
/// Largest width or height, in pixels, that an image may have before `from_image` scales it
/// down; Anthropic wants uploaded images to stay within this bound in both dimensions.
// NOTE(review): the identifier is misspelled (`LIMT`); renaming it also requires updating its uses.
const ANTHROPIC_SIZE_LIMT: f32 = 1568.0;
45
46impl LanguageModelImage {
47 pub fn empty() -> Self {
48 Self {
49 source: "".into(),
50 size: size(DevicePixels(0), DevicePixels(0)),
51 }
52 }
53
54 pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
55 cx.background_spawn(async move {
56 let image_bytes = Cursor::new(data.bytes());
57 let dynamic_image = match data.format() {
58 ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
59 .and_then(image::DynamicImage::from_decoder),
60 ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
61 .and_then(image::DynamicImage::from_decoder),
62 ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
63 .and_then(image::DynamicImage::from_decoder),
64 ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
65 .and_then(image::DynamicImage::from_decoder),
66 _ => return None,
67 }
68 .log_err()?;
69
70 let width = dynamic_image.width();
71 let height = dynamic_image.height();
72 let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));
73
74 let base64_image = {
75 if image_size.width.0 > ANTHROPIC_SIZE_LIMT as i32
76 || image_size.height.0 > ANTHROPIC_SIZE_LIMT as i32
77 {
78 let new_bounds = ObjectFit::ScaleDown.get_bounds(
79 gpui::Bounds {
80 origin: point(px(0.0), px(0.0)),
81 size: size(px(ANTHROPIC_SIZE_LIMT), px(ANTHROPIC_SIZE_LIMT)),
82 },
83 image_size,
84 );
85 let resized_image = dynamic_image.resize(
86 new_bounds.size.width.0 as u32,
87 new_bounds.size.height.0 as u32,
88 image::imageops::FilterType::Triangle,
89 );
90
91 encode_as_base64(data, resized_image)
92 } else {
93 encode_as_base64(data, dynamic_image)
94 }
95 }
96 .log_err()?;
97
98 // SAFETY: The base64 encoder should not produce non-UTF8.
99 let source = unsafe { String::from_utf8_unchecked(base64_image) };
100
101 Some(LanguageModelImage {
102 size: image_size,
103 source: source.into(),
104 })
105 })
106 }
107
108 pub fn estimate_tokens(&self) -> usize {
109 let width = self.size.width.0.unsigned_abs() as usize;
110 let height = self.size.height.0.unsigned_abs() as usize;
111
112 // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
113 // Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
114 // so this method is more of a rough guess.
115 (width * height) / 750
116 }
117
118 pub fn to_base64_url(&self) -> String {
119 format!("data:image/png;base64,{}", self.source)
120 }
121}
122
123fn encode_as_base64(data: Arc<Image>, image: image::DynamicImage) -> Result<Vec<u8>> {
124 let mut base64_image = Vec::new();
125 {
126 let mut base64_encoder = EncoderWriter::new(
127 Cursor::new(&mut base64_image),
128 &base64::engine::general_purpose::STANDARD,
129 );
130 if data.format() == ImageFormat::Png {
131 base64_encoder.write_all(data.bytes())?;
132 } else {
133 let mut png = Vec::new();
134 image.write_with_encoder(PngEncoder::new(&mut png))?;
135
136 base64_encoder.write_all(png.as_slice())?;
137 }
138 }
139 Ok(base64_image)
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
143pub struct LanguageModelToolResult {
144 pub tool_use_id: LanguageModelToolUseId,
145 pub tool_name: Arc<str>,
146 pub is_error: bool,
147 pub content: LanguageModelToolResultContent,
148 pub output: Option<serde_json::Value>,
149}
150
151#[derive(Debug, Clone, Deserialize, Serialize, Eq, PartialEq, Hash)]
152#[serde(untagged)]
153pub enum LanguageModelToolResultContent {
154 Text(Arc<str>),
155 Image(LanguageModelImage),
156}
157
158impl LanguageModelToolResultContent {
159 pub fn to_str(&self) -> Option<&str> {
160 match self {
161 Self::Text(text) => Some(&text),
162 Self::Image(_) => None,
163 }
164 }
165
166 pub fn is_empty(&self) -> bool {
167 match self {
168 Self::Text(text) => text.chars().all(|c| c.is_whitespace()),
169 Self::Image(_) => false,
170 }
171 }
172}
173
174impl From<&str> for LanguageModelToolResultContent {
175 fn from(value: &str) -> Self {
176 Self::Text(Arc::from(value))
177 }
178}
179
180impl From<String> for LanguageModelToolResultContent {
181 fn from(value: String) -> Self {
182 Self::Text(Arc::from(value))
183 }
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
187pub enum MessageContent {
188 Text(String),
189 Thinking {
190 text: String,
191 signature: Option<String>,
192 },
193 RedactedThinking(Vec<u8>),
194 Image(LanguageModelImage),
195 ToolUse(LanguageModelToolUse),
196 ToolResult(LanguageModelToolResult),
197}
198
199impl MessageContent {
200 pub fn to_str(&self) -> Option<&str> {
201 match self {
202 MessageContent::Text(text) => Some(text.as_str()),
203 MessageContent::Thinking { text, .. } => Some(text.as_str()),
204 MessageContent::RedactedThinking(_) => None,
205 MessageContent::ToolResult(tool_result) => tool_result.content.to_str(),
206 MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
207 }
208 }
209
210 pub fn is_empty(&self) -> bool {
211 match self {
212 MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
213 MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
214 MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(),
215 MessageContent::RedactedThinking(_)
216 | MessageContent::ToolUse(_)
217 | MessageContent::Image(_) => false,
218 }
219 }
220}
221
222impl From<String> for MessageContent {
223 fn from(value: String) -> Self {
224 MessageContent::Text(value)
225 }
226}
227
228impl From<&str> for MessageContent {
229 fn from(value: &str) -> Self {
230 MessageContent::Text(value.to_string())
231 }
232}
233
234#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
235pub struct LanguageModelRequestMessage {
236 pub role: Role,
237 pub content: Vec<MessageContent>,
238 pub cache: bool,
239}
240
241impl LanguageModelRequestMessage {
242 pub fn string_contents(&self) -> String {
243 let mut buffer = String::new();
244 for string in self.content.iter().filter_map(|content| content.to_str()) {
245 buffer.push_str(string);
246 }
247
248 buffer
249 }
250
251 pub fn contents_empty(&self) -> bool {
252 self.content.iter().all(|content| content.is_empty())
253 }
254}
255
256#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
257pub struct LanguageModelRequestTool {
258 pub name: String,
259 pub description: String,
260 pub input_schema: serde_json::Value,
261}
262
263#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
264pub enum LanguageModelToolChoice {
265 Auto,
266 Any,
267 None,
268}
269
270#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
271pub struct LanguageModelRequest {
272 pub thread_id: Option<String>,
273 pub prompt_id: Option<String>,
274 pub mode: Option<CompletionMode>,
275 pub messages: Vec<LanguageModelRequestMessage>,
276 pub tools: Vec<LanguageModelRequestTool>,
277 pub tool_choice: Option<LanguageModelToolChoice>,
278 pub stop: Vec<String>,
279 pub temperature: Option<f32>,
280}
281
282#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
283pub struct LanguageModelResponseMessage {
284 pub role: Option<Role>,
285 pub content: Option<String>,
286}