agent: Support images via @file and the file context picker (#29596)

Created by Bennet Bo Fenner and Oleksiy Syvokon

Release Notes:

- agent: Add support for @mentioning images
- agent: Add support for including images via file context picker

---------

Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>

Change summary

crates/agent/src/context.rs          |   1 
crates/agent/src/context_store.rs    |  77 +++++++++++++++--
crates/agent/src/message_editor.rs   |   2 
crates/agent/src/ui/context_pill.rs  |   3 
crates/language_model/src/request.rs | 129 +++++++++--------------------
crates/project/src/image_store.rs    |  49 ++++++-----
6 files changed, 140 insertions(+), 121 deletions(-)

Detailed changes

crates/agent/src/context.rs 🔗

@@ -630,6 +630,7 @@ impl Display for RulesContext {
 
 #[derive(Debug, Clone)]
 pub struct ImageContext {
+    pub project_path: Option<ProjectPath>,
     pub original_image: Arc<gpui::Image>,
     // TODO: handle this elsewhere and remove `ignore-interior-mutability` opt-out in clippy.toml
     // needed due to a false positive of `clippy::mutable_key_type`.

crates/agent/src/context_store.rs 🔗

@@ -9,6 +9,7 @@ use futures::{self, FutureExt};
 use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity};
 use language::Buffer;
 use language_model::LanguageModelImage;
+use project::image_store::is_image_file;
 use project::{Project, ProjectItem, ProjectPath, Symbol};
 use prompt_store::UserPromptId;
 use ref_cast::RefCast as _;
@@ -85,15 +86,19 @@ impl ContextStore {
             return Task::ready(Err(anyhow!("failed to read project")));
         };
 
-        cx.spawn(async move |this, cx| {
-            let open_buffer_task = project.update(cx, |project, cx| {
-                project.open_buffer(project_path.clone(), cx)
-            })?;
-            let buffer = open_buffer_task.await?;
-            this.update(cx, |this, cx| {
-                this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
+        if is_image_file(&project, &project_path, cx) {
+            self.add_image_from_path(project_path, remove_if_exists, cx)
+        } else {
+            cx.spawn(async move |this, cx| {
+                let open_buffer_task = project.update(cx, |project, cx| {
+                    project.open_buffer(project_path.clone(), cx)
+                })?;
+                let buffer = open_buffer_task.await?;
+                this.update(cx, |this, cx| {
+                    this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
+                })
             })
-        })
+        }
     }
 
     pub fn add_file_from_buffer(
@@ -272,13 +277,55 @@ impl ContextStore {
         self.insert_context(context, cx);
     }
 
-    pub fn add_image(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
+    pub fn add_image_from_path(
+        &mut self,
+        project_path: ProjectPath,
+        remove_if_exists: bool,
+        cx: &mut Context<ContextStore>,
+    ) -> Task<Result<()>> {
+        let project = self.project.clone();
+        cx.spawn(async move |this, cx| {
+            let open_image_task = project.update(cx, |project, cx| {
+                project.open_image(project_path.clone(), cx)
+            })?;
+            let image_item = open_image_task.await?;
+            let image = image_item.read_with(cx, |image_item, _| image_item.image.clone())?;
+            this.update(cx, |this, cx| {
+                this.insert_image(
+                    Some(image_item.read(cx).project_path(cx)),
+                    image,
+                    remove_if_exists,
+                    cx,
+                );
+            })
+        })
+    }
+
+    pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
+        self.insert_image(None, image, false, cx);
+    }
+
+    fn insert_image(
+        &mut self,
+        project_path: Option<ProjectPath>,
+        image: Arc<Image>,
+        remove_if_exists: bool,
+        cx: &mut Context<ContextStore>,
+    ) {
         let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
         let context = AgentContextHandle::Image(ImageContext {
+            project_path,
             original_image: image,
             image_task,
             context_id: self.next_context_id.post_inc(),
         });
+        if self.has_context(&context) {
+            if remove_if_exists {
+                self.remove_context(&context, cx);
+                return;
+            }
+        }
+
         self.insert_context(context, cx);
     }
 
@@ -373,6 +420,9 @@ impl ContextStore {
             AgentContextHandle::File(file_context) => {
                 FileInclusion::check_file(file_context, path, cx)
             }
+            AgentContextHandle::Image(image_context) => {
+                FileInclusion::check_image(image_context, path)
+            }
             AgentContextHandle::Directory(directory_context) => {
                 FileInclusion::check_directory(directory_context, path, project, cx)
             }
@@ -467,6 +517,15 @@ impl FileInclusion {
         }
     }
 
+    fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
+        let image_path = image_context.project_path.as_ref()?;
+        if path == image_path {
+            Some(FileInclusion::Direct)
+        } else {
+            None
+        }
+    }
+
     fn check_directory(
         directory_context: &DirectoryContextHandle,
         path: &ProjectPath,

crates/agent/src/message_editor.rs 🔗

@@ -396,7 +396,7 @@ impl MessageEditor {
 
         self.context_store.update(cx, |store, cx| {
             for image in images {
-                store.add_image(Arc::new(image), cx);
+                store.add_image_instance(Arc::new(image), cx);
             }
         });
     }

crates/agent/src/ui/context_pill.rs 🔗

@@ -723,6 +723,7 @@ impl Component for AddedContext {
             "Ready",
             AddedContext::image(ImageContext {
                 context_id: next_context_id.post_inc(),
+                project_path: None,
                 original_image: Arc::new(Image::empty()),
                 image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
             }),
@@ -732,6 +733,7 @@ impl Component for AddedContext {
             "Loading",
             AddedContext::image(ImageContext {
                 context_id: next_context_id.post_inc(),
+                project_path: None,
                 original_image: Arc::new(Image::empty()),
                 image_task: cx
                     .background_spawn(async move {
@@ -746,6 +748,7 @@ impl Component for AddedContext {
             "Error",
             AddedContext::image(ImageContext {
                 context_id: next_context_id.post_inc(),
+                project_path: None,
                 original_image: Arc::new(Image::empty()),
                 image_task: Task::ready(None).shared(),
             }),

crates/language_model/src/request.rs 🔗

@@ -3,12 +3,13 @@ use std::sync::Arc;
 
 use crate::role::Role;
 use crate::{LanguageModelToolUse, LanguageModelToolUseId};
+use anyhow::Result;
 use base64::write::EncoderWriter;
 use gpui::{
-    App, AppContext as _, DevicePixels, Image, ObjectFit, RenderImage, SharedString, Size, Task,
+    App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
     point, px, size,
 };
-use image::{DynamicImage, ImageDecoder, codecs::png::PngEncoder, imageops::resize};
+use image::codecs::png::PngEncoder;
 use serde::{Deserialize, Serialize};
 use util::ResultExt;
 use zed_llm_client::CompletionMode;
@@ -42,26 +43,25 @@ impl LanguageModelImage {
 
     pub fn from_image(data: Arc<Image>, cx: &mut App) -> Task<Option<Self>> {
         cx.background_spawn(async move {
-            match data.format() {
-                gpui::ImageFormat::Png
-                | gpui::ImageFormat::Jpeg
-                | gpui::ImageFormat::Webp
-                | gpui::ImageFormat::Gif => {}
+            let image_bytes = Cursor::new(data.bytes());
+            let dynamic_image = match data.format() {
+                ImageFormat::Png => image::codecs::png::PngDecoder::new(image_bytes)
+                    .and_then(image::DynamicImage::from_decoder),
+                ImageFormat::Jpeg => image::codecs::jpeg::JpegDecoder::new(image_bytes)
+                    .and_then(image::DynamicImage::from_decoder),
+                ImageFormat::Webp => image::codecs::webp::WebPDecoder::new(image_bytes)
+                    .and_then(image::DynamicImage::from_decoder),
+                ImageFormat::Gif => image::codecs::gif::GifDecoder::new(image_bytes)
+                    .and_then(image::DynamicImage::from_decoder),
                 _ => return None,
-            };
+            }
+            .log_err()?;
 
-            let image = image::codecs::png::PngDecoder::new(Cursor::new(data.bytes())).log_err()?;
-            let (width, height) = image.dimensions();
+            let width = dynamic_image.width();
+            let height = dynamic_image.height();
             let image_size = size(DevicePixels(width as i32), DevicePixels(height as i32));
 
-            let mut base64_image = Vec::new();
-
-            {
-                let mut base64_encoder = EncoderWriter::new(
-                    Cursor::new(&mut base64_image),
-                    &base64::engine::general_purpose::STANDARD,
-                );
-
+            let base64_image = {
                 if image_size.width.0 > ANTHROPIC_SIZE_LIMT as i32
                     || image_size.height.0 > ANTHROPIC_SIZE_LIMT as i32
                 {
@@ -72,22 +72,18 @@ impl LanguageModelImage {
                         },
                         image_size,
                     );
-                    let image = DynamicImage::from_decoder(image).log_err()?.resize(
+                    let resized_image = dynamic_image.resize(
                         new_bounds.size.width.0 as u32,
                         new_bounds.size.height.0 as u32,
                         image::imageops::FilterType::Triangle,
                     );
 
-                    let mut png = Vec::new();
-                    image
-                        .write_with_encoder(PngEncoder::new(&mut png))
-                        .log_err()?;
-
-                    base64_encoder.write_all(png.as_slice()).log_err()?;
+                    encode_as_base64(data, resized_image)
                 } else {
-                    base64_encoder.write_all(data.bytes()).log_err()?;
+                    encode_as_base64(data, dynamic_image)
                 }
             }
+            .log_err()?;
 
             // SAFETY: The base64 encoder should not produce non-UTF8.
             let source = unsafe { String::from_utf8_unchecked(base64_image) };
@@ -99,68 +95,6 @@ impl LanguageModelImage {
         })
     }
 
-    /// Resolves image into an LLM-ready format (base64).
-    pub fn from_render_image(data: &RenderImage) -> Option<Self> {
-        let image_size = data.size(0);
-
-        let mut bytes = data.as_bytes(0).unwrap_or(&[]).to_vec();
-        // Convert from BGRA to RGBA.
-        for pixel in bytes.chunks_exact_mut(4) {
-            pixel.swap(2, 0);
-        }
-        let mut image = image::RgbaImage::from_vec(
-            image_size.width.0 as u32,
-            image_size.height.0 as u32,
-            bytes,
-        )
-        .expect("We already know this works");
-
-        // https://docs.anthropic.com/en/docs/build-with-claude/vision
-        if image_size.width.0 > ANTHROPIC_SIZE_LIMT as i32
-            || image_size.height.0 > ANTHROPIC_SIZE_LIMT as i32
-        {
-            let new_bounds = ObjectFit::ScaleDown.get_bounds(
-                gpui::Bounds {
-                    origin: point(px(0.0), px(0.0)),
-                    size: size(px(ANTHROPIC_SIZE_LIMT), px(ANTHROPIC_SIZE_LIMT)),
-                },
-                image_size,
-            );
-
-            image = resize(
-                &image,
-                new_bounds.size.width.0 as u32,
-                new_bounds.size.height.0 as u32,
-                image::imageops::FilterType::Triangle,
-            );
-        }
-
-        let mut png = Vec::new();
-
-        image
-            .write_with_encoder(PngEncoder::new(&mut png))
-            .log_err()?;
-
-        let mut base64_image = Vec::new();
-
-        {
-            let mut base64_encoder = EncoderWriter::new(
-                Cursor::new(&mut base64_image),
-                &base64::engine::general_purpose::STANDARD,
-            );
-
-            base64_encoder.write_all(png.as_slice()).log_err()?;
-        }
-
-        // SAFETY: The base64 encoder should not produce non-UTF8.
-        let source = unsafe { String::from_utf8_unchecked(base64_image) };
-
-        Some(LanguageModelImage {
-            size: image_size,
-            source: source.into(),
-        })
-    }
-
     pub fn estimate_tokens(&self) -> usize {
         let width = self.size.width.0.unsigned_abs() as usize;
         let height = self.size.height.0.unsigned_abs() as usize;
@@ -172,6 +106,25 @@ impl LanguageModelImage {
     }
 }
 
+fn encode_as_base64(data: Arc<Image>, image: image::DynamicImage) -> Result<Vec<u8>> {
+    let mut base64_image = Vec::new();
+    {
+        let mut base64_encoder = EncoderWriter::new(
+            Cursor::new(&mut base64_image),
+            &base64::engine::general_purpose::STANDARD,
+        );
+        if data.format() == ImageFormat::Png {
+            base64_encoder.write_all(data.bytes())?;
+        } else {
+            let mut png = Vec::new();
+            image.write_with_encoder(PngEncoder::new(&mut png))?;
+
+            base64_encoder.write_all(png.as_slice())?;
+        }
+    }
+    Ok(base64_image)
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
 pub struct LanguageModelToolResult {
     pub tool_use_id: LanguageModelToolUseId,

crates/project/src/image_store.rs 🔗

@@ -210,38 +210,41 @@ impl ImageItem {
     }
 }
 
-impl ProjectItem for ImageItem {
-    fn try_open(
-        project: &Entity<Project>,
-        path: &ProjectPath,
-        cx: &mut App,
-    ) -> Option<Task<gpui::Result<Entity<Self>>>> {
-        let path = path.clone();
-        let project = project.clone();
-
+pub fn is_image_file(project: &Entity<Project>, path: &ProjectPath, cx: &App) -> bool {
+    let ext = util::maybe!({
         let worktree_abs_path = project
             .read(cx)
             .worktree_for_id(path.worktree_id, cx)?
             .read(cx)
             .abs_path();
-
-        // Resolve the file extension from either the worktree path (if it's a single file)
-        // or from the project path's subpath.
-        let ext = worktree_abs_path
+        worktree_abs_path
             .extension()
             .or_else(|| path.path.extension())
             .and_then(OsStr::to_str)
             .map(str::to_lowercase)
-            .unwrap_or_default();
-        let ext = ext.as_str();
-
-        // Only open the item if it's a binary image (no SVGs, etc.)
-        // Since we do not have a way to toggle to an editor
-        if Img::extensions().contains(&ext) && !ext.contains("svg") {
-            Some(cx.spawn(async move |cx| {
-                project
-                    .update(cx, |project, cx| project.open_image(path, cx))?
-                    .await
+    });
+
+    match ext {
+        Some(ext) => Img::extensions().contains(&ext.as_str()) && !ext.contains("svg"),
+        None => false,
+    }
+}
+
+impl ProjectItem for ImageItem {
+    fn try_open(
+        project: &Entity<Project>,
+        path: &ProjectPath,
+        cx: &mut App,
+    ) -> Option<Task<gpui::Result<Entity<Self>>>> {
+        if is_image_file(&project, &path, cx) {
+            Some(cx.spawn({
+                let path = path.clone();
+                let project = project.clone();
+                async move |cx| {
+                    project
+                        .update(cx, |project, cx| project.open_image(path, cx))?
+                        .await
+                }
             }))
         } else {
             None