context.rs

  1use std::{
  2    ops::Range,
  3    path::{Path, PathBuf},
  4    sync::Arc,
  5};
  6
  7use futures::{FutureExt, future::Shared};
  8use gpui::{App, Entity, SharedString, Task};
  9use language::Buffer;
 10use language_model::{LanguageModelImage, LanguageModelRequestMessage};
 11use project::{ProjectEntryId, ProjectPath, Worktree};
 12use prompt_store::UserPromptId;
 13use rope::Point;
 14use serde::{Deserialize, Serialize};
 15use text::{Anchor, BufferId};
 16use ui::IconName;
 17use util::post_inc;
 18
 19use crate::thread::Thread;
 20
/// Icon used for rules context; [`ContextKind::icon`] returns it for [`ContextKind::Rules`].
pub const RULES_ICON: IconName = IconName::Context;
 22
 23#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 24pub struct ContextId(pub(crate) usize);
 25
 26impl ContextId {
 27    pub fn post_inc(&mut self) -> Self {
 28        Self(post_inc(&mut self.0))
 29    }
 30}
 31
 32pub enum ContextKind {
 33    File,
 34    Directory,
 35    Symbol,
 36    Selection,
 37    FetchedUrl,
 38    Thread,
 39    Rules,
 40    Image,
 41}
 42
 43impl ContextKind {
 44    pub fn icon(&self) -> IconName {
 45        match self {
 46            ContextKind::File => IconName::File,
 47            ContextKind::Directory => IconName::Folder,
 48            ContextKind::Symbol => IconName::Code,
 49            ContextKind::Selection => IconName::Context,
 50            ContextKind::FetchedUrl => IconName::Globe,
 51            ContextKind::Thread => IconName::MessageBubbles,
 52            ContextKind::Rules => RULES_ICON,
 53            ContextKind::Image => IconName::Image,
 54        }
 55    }
 56}
 57
 58#[derive(Debug, Clone)]
 59pub enum AssistantContext {
 60    File(FileContext),
 61    Directory(DirectoryContext),
 62    Symbol(SymbolContext),
 63    FetchedUrl(FetchedUrlContext),
 64    Thread(ThreadContext),
 65    Selection(SelectionContext),
 66    Rules(RulesContext),
 67    Image(ImageContext),
 68}
 69
 70impl AssistantContext {
 71    pub fn id(&self) -> ContextId {
 72        match self {
 73            Self::File(file) => file.id,
 74            Self::Directory(directory) => directory.id,
 75            Self::Symbol(symbol) => symbol.id,
 76            Self::FetchedUrl(url) => url.id,
 77            Self::Thread(thread) => thread.id,
 78            Self::Selection(selection) => selection.id,
 79            Self::Rules(rules) => rules.id,
 80            Self::Image(image) => image.id,
 81        }
 82    }
 83}
 84
/// Context for a single file.
///
/// `context_buffer.text` is emitted verbatim into the `<files>` section by
/// `format_context_as_string`.
#[derive(Debug, Clone)]
pub struct FileContext {
    pub id: ContextId,
    pub context_buffer: ContextBuffer,
}
 90
 91#[derive(Debug, Clone)]
 92pub struct DirectoryContext {
 93    pub id: ContextId,
 94    pub worktree: Entity<Worktree>,
 95    pub entry_id: ProjectEntryId,
 96    pub last_path: Arc<Path>,
 97    /// Buffers of the files within the directory.
 98    pub context_buffers: Vec<ContextBuffer>,
 99}
100
101impl DirectoryContext {
102    pub fn entry<'a>(&self, cx: &'a App) -> Option<&'a project::Entry> {
103        self.worktree.read(cx).entry_for_id(self.entry_id)
104    }
105
106    pub fn project_path(&self, cx: &App) -> Option<ProjectPath> {
107        let worktree = self.worktree.read(cx);
108        worktree
109            .entry_for_id(self.entry_id)
110            .map(|entry| ProjectPath {
111                worktree_id: worktree.id(),
112                path: entry.path.clone(),
113            })
114    }
115}
116
/// Context for a single code symbol.
///
/// `context_symbol.text` is emitted into the `<symbols>` section by
/// `format_context_as_string`.
#[derive(Debug, Clone)]
pub struct SymbolContext {
    pub id: ContextId,
    pub context_symbol: ContextSymbol,
}
122
/// Context for a URL and its associated text.
#[derive(Debug, Clone)]
pub struct FetchedUrlContext {
    pub id: ContextId,
    /// Written on its own line at the start of this item in the `<fetched_urls>` section.
    pub url: SharedString,
    /// Written after `url` in the `<fetched_urls>` section (presumably the fetched page
    /// content — confirm against the fetcher).
    pub text: SharedString,
}
129
130#[derive(Debug, Clone)]
131pub struct ThreadContext {
132    pub id: ContextId,
133    // TODO: Entity<Thread> holds onto the thread even if the thread is deleted. Should probably be
134    // a WeakEntity and handle removal from the UI when it has dropped.
135    pub thread: Entity<Thread>,
136    pub text: SharedString,
137}
138
139impl ThreadContext {
140    pub fn summary(&self, cx: &App) -> SharedString {
141        self.thread
142            .read(cx)
143            .summary()
144            .unwrap_or("New thread".into())
145    }
146}
147
148#[derive(Debug, Clone)]
149pub struct ImageContext {
150    pub id: ContextId,
151    pub original_image: Arc<gpui::Image>,
152    pub image_task: Shared<Task<Option<LanguageModelImage>>>,
153}
154
155impl ImageContext {
156    pub fn image(&self) -> Option<LanguageModelImage> {
157        self.image_task.clone().now_or_never().flatten()
158    }
159
160    pub fn is_loading(&self) -> bool {
161        self.image_task.clone().now_or_never().is_none()
162    }
163
164    pub fn is_error(&self) -> bool {
165        self.image_task
166            .clone()
167            .now_or_never()
168            .map(|result| result.is_none())
169            .unwrap_or(false)
170    }
171}
172
173#[derive(Clone)]
174pub struct ContextBuffer {
175    pub id: BufferId,
176    // TODO: Entity<Buffer> holds onto the buffer even if the buffer is deleted. Should probably be
177    // a WeakEntity and handle removal from the UI when it has dropped.
178    pub buffer: Entity<Buffer>,
179    pub last_full_path: Arc<Path>,
180    pub version: clock::Global,
181    pub text: SharedString,
182}
183
184impl ContextBuffer {
185    pub fn full_path(&self, cx: &App) -> PathBuf {
186        let file = self.buffer.read(cx).file();
187        // Note that in practice file can't be `None` because it is present when this is created and
188        // there's no way for buffers to go from having a file to not.
189        file.map_or(self.last_full_path.to_path_buf(), |file| file.full_path(cx))
190    }
191}
192
193impl std::fmt::Debug for ContextBuffer {
194    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195        f.debug_struct("ContextBuffer")
196            .field("id", &self.id)
197            .field("buffer", &self.buffer)
198            .field("version", &self.version)
199            .field("text", &self.text)
200            .finish()
201    }
202}
203
/// A code symbol captured from a buffer, with the text sent as context.
#[derive(Debug, Clone)]
pub struct ContextSymbol {
    pub id: ContextSymbolId,
    pub buffer: Entity<Buffer>,
    // Presumably the buffer version at capture time — confirm.
    pub buffer_version: clock::Global,
    /// The range that the symbol encloses, e.g. for function symbol, this will
    /// include not only the signature, but also the body
    pub enclosing_range: Range<Anchor>,
    /// Written into the `<symbols>` section, followed by a newline.
    pub text: SharedString,
}
214
/// Identifies a symbol by project path, name, and range.
///
/// Equality and hashing are derived, so all three fields participate.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ContextSymbolId {
    pub path: ProjectPath,
    pub name: SharedString,
    pub range: Range<Anchor>,
}
221
/// Context for a selected range within a buffer.
#[derive(Debug, Clone)]
pub struct SelectionContext {
    pub id: ContextId,
    /// The selection as anchors into the buffer.
    pub range: Range<Anchor>,
    /// The selection as points (line/column positions).
    pub line_range: Range<Point>,
    /// `context_buffer.text` is written into the `<selections>` section, followed by a newline.
    pub context_buffer: ContextBuffer,
}
229
/// Context carrying a user rules prompt.
#[derive(Debug, Clone)]
pub struct RulesContext {
    pub id: ContextId,
    pub prompt_id: UserPromptId,
    pub title: SharedString,
    /// Written into the `<user_rules>` section, followed by a newline.
    pub text: SharedString,
}
237
238/// Formats a collection of contexts into a string representation
239pub fn format_context_as_string<'a>(
240    contexts: impl Iterator<Item = &'a AssistantContext>,
241    cx: &App,
242) -> Option<String> {
243    let mut file_context = Vec::new();
244    let mut directory_context = Vec::new();
245    let mut symbol_context = Vec::new();
246    let mut selection_context = Vec::new();
247    let mut fetch_context = Vec::new();
248    let mut thread_context = Vec::new();
249    let mut rules_context = Vec::new();
250
251    for context in contexts {
252        match context {
253            AssistantContext::File(context) => file_context.push(context),
254            AssistantContext::Directory(context) => directory_context.push(context),
255            AssistantContext::Symbol(context) => symbol_context.push(context),
256            AssistantContext::Selection(context) => selection_context.push(context),
257            AssistantContext::FetchedUrl(context) => fetch_context.push(context),
258            AssistantContext::Thread(context) => thread_context.push(context),
259            AssistantContext::Rules(context) => rules_context.push(context),
260            AssistantContext::Image(_) => {}
261        }
262    }
263
264    if file_context.is_empty()
265        && directory_context.is_empty()
266        && symbol_context.is_empty()
267        && selection_context.is_empty()
268        && fetch_context.is_empty()
269        && thread_context.is_empty()
270        && rules_context.is_empty()
271    {
272        return None;
273    }
274
275    let mut result = String::new();
276    result.push_str("\n<context>\n\
277        The following items were attached by the user. You don't need to use other tools to read them.\n\n");
278
279    if !file_context.is_empty() {
280        result.push_str("<files>\n");
281        for context in file_context {
282            result.push_str(&context.context_buffer.text);
283        }
284        result.push_str("</files>\n");
285    }
286
287    if !directory_context.is_empty() {
288        result.push_str("<directories>\n");
289        for context in directory_context {
290            for context_buffer in &context.context_buffers {
291                result.push_str(&context_buffer.text);
292            }
293        }
294        result.push_str("</directories>\n");
295    }
296
297    if !symbol_context.is_empty() {
298        result.push_str("<symbols>\n");
299        for context in symbol_context {
300            result.push_str(&context.context_symbol.text);
301            result.push('\n');
302        }
303        result.push_str("</symbols>\n");
304    }
305
306    if !selection_context.is_empty() {
307        result.push_str("<selections>\n");
308        for context in selection_context {
309            result.push_str(&context.context_buffer.text);
310            result.push('\n');
311        }
312        result.push_str("</selections>\n");
313    }
314
315    if !fetch_context.is_empty() {
316        result.push_str("<fetched_urls>\n");
317        for context in &fetch_context {
318            result.push_str(&context.url);
319            result.push('\n');
320            result.push_str(&context.text);
321            result.push('\n');
322        }
323        result.push_str("</fetched_urls>\n");
324    }
325
326    if !thread_context.is_empty() {
327        result.push_str("<conversation_threads>\n");
328        for context in &thread_context {
329            result.push_str(&context.summary(cx));
330            result.push('\n');
331            result.push_str(&context.text);
332            result.push('\n');
333        }
334        result.push_str("</conversation_threads>\n");
335    }
336
337    if !rules_context.is_empty() {
338        result.push_str(
339            "<user_rules>\n\
340            The user has specified the following rules that should be applied:\n\n",
341        );
342        for context in &rules_context {
343            result.push_str(&context.text);
344            result.push('\n');
345        }
346        result.push_str("</user_rules>\n");
347    }
348
349    result.push_str("</context>\n");
350    Some(result)
351}
352
353pub fn attach_context_to_message<'a>(
354    message: &mut LanguageModelRequestMessage,
355    contexts: impl Iterator<Item = &'a AssistantContext>,
356    cx: &App,
357) {
358    if let Some(context_string) = format_context_as_string(contexts, cx) {
359        message.content.push(context_string.into());
360    }
361}