use std::io::{Cursor, Write};
use std::sync::Arc;

use crate::role::Role;
use crate::{LanguageModelToolUse, LanguageModelToolUseId};
use anyhow::Result;
use base64::write::EncoderWriter;
use gpui::{
    App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
    point, px, size,
};
use image::codecs::png::PngEncoder;
use serde::{Deserialize, Serialize};
use util::ResultExt;
use zed_llm_client::CompletionMode;
/// An image that can be attached to a language model request.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
    /// A base64-encoded PNG image.
    pub source: SharedString,
    size: Size<DevicePixels>,
}

impl std::fmt::Debug for LanguageModelImage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanguageModelImage")
            .field("source", &format!("<{} bytes>", self.source.len()))
            .field("size", &self.size)
            .finish()
    }
}

/// Anthropic wants uploaded images to be no larger than this in either dimension.
const ANTHROPIC_SIZE_LIMIT: f32 = 1568.;

impl LanguageModelImage {
    pub fn empty() -> Self {
        Self {
            source: "".into(),
            size: size(DevicePixels(0), DevicePixels(0)),
        }
    }

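    /// Decodes `data` on a background thread and returns it re-encoded as a
    /// base64 PNG, scaled down if either dimension exceeds
    /// `ANTHROPIC_SIZE_LIMIT`. Returns `None` if the image format is
    /// unsupported or decoding fails.
    ///
    /// A minimal usage sketch (assumes an `Arc<Image>` named `image` is in
    /// scope and the task is awaited in an async context):
    ///
    /// ```ignore
    /// let task = LanguageModelImage::from_image(image, cx);
    /// // Later, in an async context:
    /// if let Some(encoded) = task.await {
    ///     println!("~{} tokens", encoded.estimate_tokens());
    /// }
    /// ```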
    pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
        cx.background_spawn(async move {
            let image_bytes = Cursor::new(data.bytes());
            let dynamic_image = match data.format() {
                ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
                    .and_then(image::DynamicImage::from_decoder),
                _ => return None,
            }
            .log_err()?;

            let width = dynamic_image.width();
            let height = dynamic_image.height();
            let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));

            let base64_image = {
                if image_size.width.0 > ANTHROPIC_SIZE_LIMIT as i32
                    || image_size.height.0 > ANTHROPIC_SIZE_LIMIT as i32
                {
                    let new_bounds = ObjectFit::ScaleDown.get_bounds(
                        gpui::Bounds {
                            origin: point(px(0.0), px(0.0)),
                            size: size(px(ANTHROPIC_SIZE_LIMIT), px(ANTHROPIC_SIZE_LIMIT)),
                        },
                        image_size,
                    );
                    let resized_image = dynamic_image.resize(
                        new_bounds.size.width.0 as u32,
                        new_bounds.size.height.0 as u32,
                        image::imageops::FilterType::Triangle,
                    );

                    encode_as_base64(data, resized_image)
                } else {
                    encode_as_base64(data, dynamic_image)
                }
            }
            .log_err()?;

            // SAFETY: Base64 output is always ASCII, and therefore valid UTF-8.
            let source = unsafe { String::from_utf8_unchecked(base64_image) };

            Some(LanguageModelImage {
                size: image_size,
                source: source.into(),
            })
        })
    }

    pub fn estimate_tokens(&self) -> usize {
        let width = self.size.width.0.unsigned_abs() as usize;
        let height = self.size.height.0.unsigned_abs() as usize;

        // From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
        // Note that there are a lot of conditions on Anthropic's API, and OpenAI doesn't use this
        // formula, so this method is more of a rough guess.
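        // For example, a 1000x1000 px image comes out to 1_000_000 / 750 ≈ 1333 tokens.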
        (width * height) / 750
    }
}

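/// Base64-encodes the image using the standard encoding. PNG sources are
/// streamed through as-is; other formats are re-encoded as PNG first.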
fn encode_as_base64(data: Arc<Image>, image: image::DynamicImage) -> Result<Vec<u8>> {
    let mut base64_image = Vec::new();
    {
        let mut base64_encoder = EncoderWriter::new(
            Cursor::new(&mut base64_image),
            &base64::engine::general_purpose::STANDARD,
        );
        if data.format() == ImageFormat::Png {
            // The original bytes are already PNG, so skip the decode/re-encode round trip.
            base64_encoder.write_all(data.bytes())?;
        } else {
            let mut png = Vec::new();
            image.write_with_encoder(PngEncoder::new(&mut png))?;

            base64_encoder.write_all(png.as_slice())?;
        }
    }
    Ok(base64_image)
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LanguageModelToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub tool_name: Arc<str>,
    pub is_error: bool,
    pub content: Arc<str>,
}

#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking(Vec<u8>),
    Image(LanguageModelImage),
    ToolUse(LanguageModelToolUse),
    ToolResult(LanguageModelToolResult),
}

impl From<String> for MessageContent {
    fn from(value: String) -> Self {
        MessageContent::Text(value)
    }
}

impl From<&str> for MessageContent {
    fn from(value: &str) -> Self {
        MessageContent::Text(value.to_string())
    }
}

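/// A single message in a [`LanguageModelRequest`].
///
/// A minimal construction sketch (role and text are illustrative):
///
/// ```ignore
/// let message = LanguageModelRequestMessage {
///     role: Role::User,
///     content: vec!["Describe this image.".into()],
///     cache: false,
/// };
/// ```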
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: Vec<MessageContent>,
    pub cache: bool,
}

impl LanguageModelRequestMessage {
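    /// Concatenates all textual content in this message (text, thinking, and
    /// tool results), skipping images, tool uses, and redacted thinking.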
    pub fn string_contents(&self) -> String {
        let mut buffer = String::new();
        for string in self.content.iter().filter_map(|content| match content {
            MessageContent::Text(text) => Some(text.as_str()),
            MessageContent::Thinking { text, .. } => Some(text.as_str()),
            MessageContent::RedactedThinking(_) => None,
            MessageContent::ToolResult(tool_result) => Some(tool_result.content.as_ref()),
            MessageContent::ToolUse(_) | MessageContent::Image(_) => None,
        }) {
            buffer.push_str(string);
        }

        buffer
    }

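    /// Returns true if every content item is whitespace-only text, thinking,
    /// or tool-result output; images, tool uses, and redacted thinking always
    /// count as non-empty.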
    pub fn contents_empty(&self) -> bool {
        self.content.iter().all(|content| match content {
            MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()),
            MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()),
            MessageContent::ToolResult(tool_result) => {
                tool_result.content.chars().all(|c| c.is_whitespace())
            }
            MessageContent::RedactedThinking(_)
            | MessageContent::ToolUse(_)
            | MessageContent::Image(_) => false,
        })
    }
}

#[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelRequestTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

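/// A completion request sent to a language model.
///
/// A minimal construction sketch (field values are illustrative; the struct
/// derives `Default`, so unspecified fields can be filled in with it):
///
/// ```ignore
/// let request = LanguageModelRequest {
///     messages: vec![message],
///     temperature: Some(0.7),
///     ..Default::default()
/// };
/// ```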
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct LanguageModelRequest {
    pub thread_id: Option<String>,
    pub prompt_id: Option<String>,
    pub mode: Option<CompletionMode>,
    pub messages: Vec<LanguageModelRequestMessage>,
    pub tools: Vec<LanguageModelRequestTool>,
    pub stop: Vec<String>,
    pub temperature: Option<f32>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}