context.rs

  1use std::{
  2    ops::Range,
  3    path::{Path, PathBuf},
  4    sync::Arc,
  5};
  6
  7use gpui::{App, Entity, SharedString};
  8use language::Buffer;
  9use language_model::LanguageModelRequestMessage;
 10use project::{ProjectEntryId, ProjectPath, Worktree};
 11use prompt_store::UserPromptId;
 12use rope::Point;
 13use serde::{Deserialize, Serialize};
 14use text::{Anchor, BufferId};
 15use ui::IconName;
 16use util::post_inc;
 17
 18use crate::thread::Thread;
 19
 20pub const RULES_ICON: IconName = IconName::Context;
 21
 22#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 23pub struct ContextId(pub(crate) usize);
 24
 25impl ContextId {
 26    pub fn post_inc(&mut self) -> Self {
 27        Self(post_inc(&mut self.0))
 28    }
 29}
 30
 31pub enum ContextKind {
 32    File,
 33    Directory,
 34    Symbol,
 35    Excerpt,
 36    FetchedUrl,
 37    Thread,
 38    Rules,
 39}
 40
 41impl ContextKind {
 42    pub fn icon(&self) -> IconName {
 43        match self {
 44            ContextKind::File => IconName::File,
 45            ContextKind::Directory => IconName::Folder,
 46            ContextKind::Symbol => IconName::Code,
 47            ContextKind::Excerpt => IconName::Code,
 48            ContextKind::FetchedUrl => IconName::Globe,
 49            ContextKind::Thread => IconName::MessageBubbles,
 50            ContextKind::Rules => RULES_ICON,
 51        }
 52    }
 53}
 54
 55#[derive(Debug, Clone)]
 56pub enum AssistantContext {
 57    File(FileContext),
 58    Directory(DirectoryContext),
 59    Symbol(SymbolContext),
 60    FetchedUrl(FetchedUrlContext),
 61    Thread(ThreadContext),
 62    Excerpt(ExcerptContext),
 63    Rules(RulesContext),
 64}
 65
 66impl AssistantContext {
 67    pub fn id(&self) -> ContextId {
 68        match self {
 69            Self::File(file) => file.id,
 70            Self::Directory(directory) => directory.id,
 71            Self::Symbol(symbol) => symbol.id,
 72            Self::FetchedUrl(url) => url.id,
 73            Self::Thread(thread) => thread.id,
 74            Self::Excerpt(excerpt) => excerpt.id,
 75            Self::Rules(rules) => rules.id,
 76        }
 77    }
 78}
 79
 80#[derive(Debug, Clone)]
 81pub struct FileContext {
 82    pub id: ContextId,
 83    pub context_buffer: ContextBuffer,
 84}
 85
 86#[derive(Debug, Clone)]
 87pub struct DirectoryContext {
 88    pub id: ContextId,
 89    pub worktree: Entity<Worktree>,
 90    pub entry_id: ProjectEntryId,
 91    pub last_path: Arc<Path>,
 92    /// Buffers of the files within the directory.
 93    pub context_buffers: Vec<ContextBuffer>,
 94}
 95
 96impl DirectoryContext {
 97    pub fn entry<'a>(&self, cx: &'a App) -> Option<&'a project::Entry> {
 98        self.worktree.read(cx).entry_for_id(self.entry_id)
 99    }
100
101    pub fn project_path(&self, cx: &App) -> Option<ProjectPath> {
102        let worktree = self.worktree.read(cx);
103        worktree
104            .entry_for_id(self.entry_id)
105            .map(|entry| ProjectPath {
106                worktree_id: worktree.id(),
107                path: entry.path.clone(),
108            })
109    }
110}
111
112#[derive(Debug, Clone)]
113pub struct SymbolContext {
114    pub id: ContextId,
115    pub context_symbol: ContextSymbol,
116}
117
118#[derive(Debug, Clone)]
119pub struct FetchedUrlContext {
120    pub id: ContextId,
121    pub url: SharedString,
122    pub text: SharedString,
123}
124
125#[derive(Debug, Clone)]
126pub struct ThreadContext {
127    pub id: ContextId,
128    // TODO: Entity<Thread> holds onto the thread even if the thread is deleted. Should probably be
129    // a WeakEntity and handle removal from the UI when it has dropped.
130    pub thread: Entity<Thread>,
131    pub text: SharedString,
132}
133
134impl ThreadContext {
135    pub fn summary(&self, cx: &App) -> SharedString {
136        self.thread
137            .read(cx)
138            .summary()
139            .unwrap_or("New thread".into())
140    }
141}
142
143#[derive(Clone)]
144pub struct ContextBuffer {
145    pub id: BufferId,
146    // TODO: Entity<Buffer> holds onto the buffer even if the buffer is deleted. Should probably be
147    // a WeakEntity and handle removal from the UI when it has dropped.
148    pub buffer: Entity<Buffer>,
149    pub last_full_path: Arc<Path>,
150    pub version: clock::Global,
151    pub text: SharedString,
152}
153
154impl ContextBuffer {
155    pub fn full_path(&self, cx: &App) -> PathBuf {
156        let file = self.buffer.read(cx).file();
157        // Note that in practice file can't be `None` because it is present when this is created and
158        // there's no way for buffers to go from having a file to not.
159        file.map_or(self.last_full_path.to_path_buf(), |file| file.full_path(cx))
160    }
161}
162
163impl std::fmt::Debug for ContextBuffer {
164    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165        f.debug_struct("ContextBuffer")
166            .field("id", &self.id)
167            .field("buffer", &self.buffer)
168            .field("version", &self.version)
169            .field("text", &self.text)
170            .finish()
171    }
172}
173
174#[derive(Debug, Clone)]
175pub struct ContextSymbol {
176    pub id: ContextSymbolId,
177    pub buffer: Entity<Buffer>,
178    pub buffer_version: clock::Global,
179    /// The range that the symbol encloses, e.g. for function symbol, this will
180    /// include not only the signature, but also the body
181    pub enclosing_range: Range<Anchor>,
182    pub text: SharedString,
183}
184
185#[derive(Debug, Clone, PartialEq, Eq, Hash)]
186pub struct ContextSymbolId {
187    pub path: ProjectPath,
188    pub name: SharedString,
189    pub range: Range<Anchor>,
190}
191
192#[derive(Debug, Clone)]
193pub struct ExcerptContext {
194    pub id: ContextId,
195    pub range: Range<Anchor>,
196    pub line_range: Range<Point>,
197    pub context_buffer: ContextBuffer,
198}
199
200#[derive(Debug, Clone)]
201pub struct RulesContext {
202    pub id: ContextId,
203    pub prompt_id: UserPromptId,
204    pub title: SharedString,
205    pub text: SharedString,
206}
207
208/// Formats a collection of contexts into a string representation
209pub fn format_context_as_string<'a>(
210    contexts: impl Iterator<Item = &'a AssistantContext>,
211    cx: &App,
212) -> Option<String> {
213    let mut file_context = Vec::new();
214    let mut directory_context = Vec::new();
215    let mut symbol_context = Vec::new();
216    let mut excerpt_context = Vec::new();
217    let mut fetch_context = Vec::new();
218    let mut thread_context = Vec::new();
219    let mut rules_context = Vec::new();
220
221    for context in contexts {
222        match context {
223            AssistantContext::File(context) => file_context.push(context),
224            AssistantContext::Directory(context) => directory_context.push(context),
225            AssistantContext::Symbol(context) => symbol_context.push(context),
226            AssistantContext::Excerpt(context) => excerpt_context.push(context),
227            AssistantContext::FetchedUrl(context) => fetch_context.push(context),
228            AssistantContext::Thread(context) => thread_context.push(context),
229            AssistantContext::Rules(context) => rules_context.push(context),
230        }
231    }
232
233    if file_context.is_empty()
234        && directory_context.is_empty()
235        && symbol_context.is_empty()
236        && excerpt_context.is_empty()
237        && fetch_context.is_empty()
238        && thread_context.is_empty()
239        && rules_context.is_empty()
240    {
241        return None;
242    }
243
244    let mut result = String::new();
245    result.push_str("\n<context>\n\
246        The following items were attached by the user. You don't need to use other tools to read them.\n\n");
247
248    if !file_context.is_empty() {
249        result.push_str("<files>\n");
250        for context in file_context {
251            result.push_str(&context.context_buffer.text);
252        }
253        result.push_str("</files>\n");
254    }
255
256    if !directory_context.is_empty() {
257        result.push_str("<directories>\n");
258        for context in directory_context {
259            for context_buffer in &context.context_buffers {
260                result.push_str(&context_buffer.text);
261            }
262        }
263        result.push_str("</directories>\n");
264    }
265
266    if !symbol_context.is_empty() {
267        result.push_str("<symbols>\n");
268        for context in symbol_context {
269            result.push_str(&context.context_symbol.text);
270            result.push('\n');
271        }
272        result.push_str("</symbols>\n");
273    }
274
275    if !excerpt_context.is_empty() {
276        result.push_str("<excerpts>\n");
277        for context in excerpt_context {
278            result.push_str(&context.context_buffer.text);
279            result.push('\n');
280        }
281        result.push_str("</excerpts>\n");
282    }
283
284    if !fetch_context.is_empty() {
285        result.push_str("<fetched_urls>\n");
286        for context in &fetch_context {
287            result.push_str(&context.url);
288            result.push('\n');
289            result.push_str(&context.text);
290            result.push('\n');
291        }
292        result.push_str("</fetched_urls>\n");
293    }
294
295    if !thread_context.is_empty() {
296        result.push_str("<conversation_threads>\n");
297        for context in &thread_context {
298            result.push_str(&context.summary(cx));
299            result.push('\n');
300            result.push_str(&context.text);
301            result.push('\n');
302        }
303        result.push_str("</conversation_threads>\n");
304    }
305
306    if !rules_context.is_empty() {
307        result.push_str(
308            "<user_rules>\n\
309            The user has specified the following rules that should be applied:\n\n",
310        );
311        for context in &rules_context {
312            result.push_str(&context.text);
313            result.push('\n');
314        }
315        result.push_str("</user_rules>\n");
316    }
317
318    result.push_str("</context>\n");
319    Some(result)
320}
321
322pub fn attach_context_to_message<'a>(
323    message: &mut LanguageModelRequestMessage,
324    contexts: impl Iterator<Item = &'a AssistantContext>,
325    cx: &App,
326) {
327    if let Some(context_string) = format_context_as_string(contexts, cx) {
328        message.content.push(context_string.into());
329    }
330}