use std::io::{Cursor, Write};
use std::sync::Arc;

use crate::role::Role;
use crate::{LanguageModelToolUse, LanguageModelToolUseId};
use anyhow::Result;
use base64::write::EncoderWriter;
use gpui::{
    App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
    point, px, size,
};
use image::codecs::png::PngEncoder;
use serde::{Deserialize, Serialize};
use util::ResultExt;
use zed_llm_client::CompletionMode;

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    size: Size<DevicePixels>,
}

impl std::fmt::Debug for LanguageModelImage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanguageModelImage")
            .field("source", &format!("<{} bytes>", self.source.len()))
            .field("size", &self.size)
            .finish()
    }
}

/// Anthropic wants uploaded images to be no larger than this in either dimension.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;

impl LanguageModelImage {
    pub fn empty() -> Self {
        Self {
            source: "".into(),
            size: size(DevicePixels(0), DevicePixels(0)),
        }
    }

    /// Decodes the image on a background thread, scales it down when either
    /// dimension exceeds Anthropic's limit, and base64-encodes the result as
    /// a PNG. The returned task resolves to `None` if the format is
    /// unsupported or decoding fails.
    pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
        cx.background_spawn(async move {
            let image_bytes = Cursor::new(data.bytes());
            let dynamic_image = match data.format() {
                ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                _ => return None,
            }
            .log_err()?;

            let width = dynamic_image.width();
            let height = dynamic_image.height();
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

            let base64_image = {
                if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                    || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
                {
                    // Scale down so the longest edge fits within the limit,
                    // preserving aspect ratio: a 3000x2000 source becomes
                    // 1568x1045, for example.
                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
                        gpui::Bounds {
                            origin: point(px(0.0), px(0.0)),
                            size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                        },
                        image_size,
                    );
                    let resized_image = dynamic_image.resize(
                        new_bounds.size.width.0 as u32,
                        new_bounds.size.height.0 as u32,
                        image::imageops::FilterType::Triangle,
                    );

                    encode_as_base64(data, resized_image, true)
                } else {
                    encode_as_base64(data, dynamic_image, false)
                }
            }
            .log_err()?;

            // SAFETY: Base64 output is plain ASCII, which is always valid UTF-8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Note that there are a lot of conditions on Anthropic's API, and OpenAI
        // doesn't use this formula, so this is more of a rough guess.
        (width * height) / 750
    }
}
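
// A minimal sanity check of the token estimate above, written as an in-crate
// test so it can construct the private `size` field directly. The expected
// value simply mirrors the (width * height) / 750 integer arithmetic.
#[cfg(test)]
mod image_token_tests {
    use super::*;

    #[test]
    fn estimate_tokens_follows_the_anthropic_formula() {
        let image = LanguageModelImage {
            source: "".into(),
            size: size(DevicePixels(1000), DevicePixels(1000)),
        };
        // (1000 * 1000) / 750 = 1333 with integer division.
        assert_eq!(image.estimate_tokens(), 1333);
    }
}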

fn encode_as_base64(
    data: Arc<Image>,
    image: image::DynamicImage,
    resized: bool,
) -> Result<Vec<u8>> {
    let mut base64_image = Vec::new();
    {
        let mut base64_encoder = EncoderWriter::new(
            Cursor::new(&mut base64_image),
            &base64::engine::general_purpose::STANDARD,
        );
        if data.format() == ImageFormat::Png && !resized {
            // The source is already a PNG at an acceptable size, so its bytes
            // can be streamed through without re-encoding.
            base64_encoder.write_all(data.bytes())?;
        } else {
            // Re-encode the (possibly resized) image as a PNG before
            // base64-encoding it.
            let mut png = Vec::new();
            image.write_with_encoder(PngEncoder::new(&mut png))?;

            base64_encoder.write_all(png.as_slice())?;
        }
    }
    Ok(base64_image)
}
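
// The streaming `EncoderWriter` used above should agree with the engine's
// one-shot `encode`; this is a small sanity check of that equivalence. It
// deliberately avoids constructing a gpui `Image`, so it exercises only the
// base64 layer, not `encode_as_base64` itself.
#[cfg(test)]
mod base64_encoding_tests {
    use super::*;
    use base64::Engine as _;

    #[test]
    fn streaming_and_one_shot_encoding_agree() {
        let input = b"arbitrary bytes stand in for PNG data here";
        let mut streamed = Vec::new();
        {
            // Dropping the encoder at the end of this scope flushes the final
            // padding, just as in `encode_as_base64`.
            let mut encoder = EncoderWriter::new(
                Cursor::new(&mut streamed),
                &base64::engine::general_purpose::STANDARD,
            );
            encoder.write_all(input).unwrap();
        }
        let one_shot = base64::engine::general_purpose::STANDARD.encode(input);
        assert_eq!(streamed, one_shot.into_bytes());
    }
}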

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub tool_name: Arc<str>,
    pub is_error: bool,
    pub content: Arc<str>,
    pub output: Option<serde_json::Value>,
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
    Image(LanguageModelImage),
    ToolUse(LanguageModelToolUse),
    ToolResult(LanguageModelToolResult),
}
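
// `MessageContent` has no serde attributes, so it serializes with serde's
// default externally tagged enum representation, e.g. `{"Text":"hi"}`. This
// is the crate's internal shape, not any provider's wire format; the quick
// illustration below makes that concrete.
#[cfg(test)]
mod message_content_serde_tests {
    use super::*;

    #[test]
    fn text_serializes_externally_tagged() {
        let json = serde_json::to_value(MessageContent::Text("hi".into())).unwrap();
        assert_eq!(json, serde_json::json!({ "Text": "hi" }));
    }
}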

impl From<String> for MessageContent {
    fn from(value: String) -> Self {
        MessageContent::Text(value)
    }
}

impl From<&str> for MessageContent {
    fn from(value: &str) -> Self {
        MessageContent::Text(value.to_string())
    }
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
    pub cache: bool,
}

impl LanguageModelRequestMessage {
    /// Concatenates the text-bearing parts of the message (text, thinking,
    /// and tool results), skipping images, tool uses, and redacted thinking.
    pub fn string_contents(&self) -> String {
        let mut buffer = String::new();
        for string in self.content.iter().filter_map(|content| match content {
            MessageContent::Text(text) => Some(text.as_str()),
            MessageContent::Thinking { text, .. } => Some(text.as_str()),
            MessageContent::RedactedThinking(_) => None,
            MessageContent::ToolResult(tool_result) => Some(tool_result.content.as_ref()),
            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
        }) {
            buffer.push_str(string);
        }

        buffer
    }

    /// Whether the message carries no meaningful content. Whitespace-only
    /// text counts as empty; images, tool uses, and redacted thinking never do.
    pub fn contents_empty(&self) -> bool {
        self.content.iter().all(|content| match content {
            MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
            MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
            MessageContent::ToolResult(tool_result) => {
                tool_result.content.chars().all(|c| c.is_whitespace())
            }
            MessageContent::RedactedThinking(_)
            | MessageContent::ToolUse(_)
            | MessageContent::Image(_) => false,
        })
    }
}
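
// A short sketch of how the two helpers above treat mixed content, assuming
// the crate's `Role` enum exposes a `User` variant: text-bearing parts are
// concatenated in order, while images and tool uses are skipped.
#[cfg(test)]
mod request_message_tests {
    use super::*;

    #[test]
    fn string_contents_keeps_only_text_bearing_parts() {
        let message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                MessageContent::Text("Hello, ".into()),
                MessageContent::Image(LanguageModelImage::empty()),
                MessageContent::Thinking {
                    text: "world".into(),
                    signature: None,
                },
            ],
            cache: false,
        };
        assert_eq!(message.string_contents(), "Hello, world");
        // The image counts as content, so the message is not empty.
        assert!(!message.contents_empty());
    }

    #[test]
    fn whitespace_only_text_is_considered_empty() {
        let message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text("  \n\t".into())],
            cache: false,
        };
        assert!(message.contents_empty());
    }
}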

#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub enum LanguageModelToolChoice {
    Auto,
    Any,
    None,
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub thread_id: Option<String>,
    pub prompt_id: Option<String>,
    pub mode: Option<CompletionMode>,
    pub messages: Vec<LanguageModelRequestMessage>,
    pub tools: Vec<LanguageModelRequestTool>,
    pub tool_choice: Option<LanguageModelToolChoice>,
    pub stop: Vec<String>,
    pub temperature: Option<f32>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
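
// A minimal sketch of assembling a request. `LanguageModelRequest` derives
// `Default`, so struct-update syntax fills in the unspecified fields; the
// prompt text and temperature are purely illustrative, and `Role::User` is
// assumed to exist on the crate's `Role` enum.
#[cfg(test)]
mod request_construction_tests {
    use super::*;

    #[test]
    fn build_a_request_with_defaults() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                // `From<&str> for MessageContent` makes `.into()` produce a
                // `MessageContent::Text` variant here.
                content: vec!["Summarize this file.".into()],
                cache: false,
            }],
            temperature: Some(0.7),
            ..Default::default()
        };
        assert_eq!(request.messages.len(), 1);
        assert_eq!(
            request.messages[0].string_contents(),
            "Summarize this file."
        );
    }
}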