context.rs

  1use std::{ops::Range, path::Path, sync::Arc};
  2
  3use gpui::{App, Entity, SharedString};
  4use language::{Buffer, File};
  5use language_model::LanguageModelRequestMessage;
  6use project::{ProjectPath, Worktree};
  7use prompt_store::UserPromptId;
  8use rope::Point;
  9use serde::{Deserialize, Serialize};
 10use text::{Anchor, BufferId};
 11use ui::IconName;
 12use util::post_inc;
 13
 14use crate::thread::Thread;
 15
/// Icon used for rules context; `ContextKind::icon` returns it for `ContextKind::Rules`.
pub const RULES_ICON: IconName = IconName::Context;
 17
 18#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 19pub struct ContextId(pub(crate) usize);
 20
 21impl ContextId {
 22    pub fn post_inc(&mut self) -> Self {
 23        Self(post_inc(&mut self.0))
 24    }
 25}
 26
 27pub enum ContextKind {
 28    File,
 29    Directory,
 30    Symbol,
 31    Excerpt,
 32    FetchedUrl,
 33    Thread,
 34    Rules,
 35}
 36
 37impl ContextKind {
 38    pub fn icon(&self) -> IconName {
 39        match self {
 40            ContextKind::File => IconName::File,
 41            ContextKind::Directory => IconName::Folder,
 42            ContextKind::Symbol => IconName::Code,
 43            ContextKind::Excerpt => IconName::Code,
 44            ContextKind::FetchedUrl => IconName::Globe,
 45            ContextKind::Thread => IconName::MessageBubbles,
 46            ContextKind::Rules => RULES_ICON,
 47        }
 48    }
 49}
 50
 51#[derive(Debug, Clone)]
 52pub enum AssistantContext {
 53    File(FileContext),
 54    Directory(DirectoryContext),
 55    Symbol(SymbolContext),
 56    FetchedUrl(FetchedUrlContext),
 57    Thread(ThreadContext),
 58    Excerpt(ExcerptContext),
 59    Rules(RulesContext),
 60}
 61
 62impl AssistantContext {
 63    pub fn id(&self) -> ContextId {
 64        match self {
 65            Self::File(file) => file.id,
 66            Self::Directory(directory) => directory.id,
 67            Self::Symbol(symbol) => symbol.id,
 68            Self::FetchedUrl(url) => url.id,
 69            Self::Thread(thread) => thread.id,
 70            Self::Excerpt(excerpt) => excerpt.id,
 71            Self::Rules(rules) => rules.id,
 72        }
 73    }
 74}
 75
 76#[derive(Debug, Clone)]
 77pub struct FileContext {
 78    pub id: ContextId,
 79    pub context_buffer: ContextBuffer,
 80}
 81
 82#[derive(Debug, Clone)]
 83pub struct DirectoryContext {
 84    pub id: ContextId,
 85    pub worktree: Entity<Worktree>,
 86    pub path: Arc<Path>,
 87    /// Buffers of the files within the directory.
 88    pub context_buffers: Vec<ContextBuffer>,
 89}
 90
 91impl DirectoryContext {
 92    pub fn project_path(&self, cx: &App) -> ProjectPath {
 93        ProjectPath {
 94            worktree_id: self.worktree.read(cx).id(),
 95            path: self.path.clone(),
 96        }
 97    }
 98}
 99
100#[derive(Debug, Clone)]
101pub struct SymbolContext {
102    pub id: ContextId,
103    pub context_symbol: ContextSymbol,
104}
105
106#[derive(Debug, Clone)]
107pub struct FetchedUrlContext {
108    pub id: ContextId,
109    pub url: SharedString,
110    pub text: SharedString,
111}
112
113#[derive(Debug, Clone)]
114pub struct ThreadContext {
115    pub id: ContextId,
116    // TODO: Entity<Thread> holds onto the thread even if the thread is deleted. Should probably be
117    // a WeakEntity and handle removal from the UI when it has dropped.
118    pub thread: Entity<Thread>,
119    pub text: SharedString,
120}
121
122impl ThreadContext {
123    pub fn summary(&self, cx: &App) -> SharedString {
124        self.thread
125            .read(cx)
126            .summary()
127            .unwrap_or("New thread".into())
128    }
129}
130
131#[derive(Clone)]
132pub struct ContextBuffer {
133    pub id: BufferId,
134    // TODO: Entity<Buffer> holds onto the thread even if the thread is deleted. Should probably be
135    // a WeakEntity and handle removal from the UI when it has dropped.
136    pub buffer: Entity<Buffer>,
137    pub file: Arc<dyn File>,
138    pub version: clock::Global,
139    pub text: SharedString,
140}
141
142impl std::fmt::Debug for ContextBuffer {
143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144        f.debug_struct("ContextBuffer")
145            .field("id", &self.id)
146            .field("buffer", &self.buffer)
147            .field("version", &self.version)
148            .field("text", &self.text)
149            .finish()
150    }
151}
152
153#[derive(Debug, Clone)]
154pub struct ContextSymbol {
155    pub id: ContextSymbolId,
156    pub buffer: Entity<Buffer>,
157    pub buffer_version: clock::Global,
158    /// The range that the symbol encloses, e.g. for function symbol, this will
159    /// include not only the signature, but also the body
160    pub enclosing_range: Range<Anchor>,
161    pub text: SharedString,
162}
163
164#[derive(Debug, Clone, PartialEq, Eq, Hash)]
165pub struct ContextSymbolId {
166    pub path: ProjectPath,
167    pub name: SharedString,
168    pub range: Range<Anchor>,
169}
170
171#[derive(Debug, Clone)]
172pub struct ExcerptContext {
173    pub id: ContextId,
174    pub range: Range<Anchor>,
175    pub line_range: Range<Point>,
176    pub context_buffer: ContextBuffer,
177}
178
179#[derive(Debug, Clone)]
180pub struct RulesContext {
181    pub id: ContextId,
182    pub prompt_id: UserPromptId,
183    pub title: SharedString,
184    pub text: SharedString,
185}
186
187/// Formats a collection of contexts into a string representation
188pub fn format_context_as_string<'a>(
189    contexts: impl Iterator<Item = &'a AssistantContext>,
190    cx: &App,
191) -> Option<String> {
192    let mut file_context = Vec::new();
193    let mut directory_context = Vec::new();
194    let mut symbol_context = Vec::new();
195    let mut excerpt_context = Vec::new();
196    let mut fetch_context = Vec::new();
197    let mut thread_context = Vec::new();
198    let mut rules_context = Vec::new();
199
200    for context in contexts {
201        match context {
202            AssistantContext::File(context) => file_context.push(context),
203            AssistantContext::Directory(context) => directory_context.push(context),
204            AssistantContext::Symbol(context) => symbol_context.push(context),
205            AssistantContext::Excerpt(context) => excerpt_context.push(context),
206            AssistantContext::FetchedUrl(context) => fetch_context.push(context),
207            AssistantContext::Thread(context) => thread_context.push(context),
208            AssistantContext::Rules(context) => rules_context.push(context),
209        }
210    }
211
212    if file_context.is_empty()
213        && directory_context.is_empty()
214        && symbol_context.is_empty()
215        && excerpt_context.is_empty()
216        && fetch_context.is_empty()
217        && thread_context.is_empty()
218        && rules_context.is_empty()
219    {
220        return None;
221    }
222
223    let mut result = String::new();
224    result.push_str("\n<context>\n\
225        The following items were attached by the user. You don't need to use other tools to read them.\n\n");
226
227    if !file_context.is_empty() {
228        result.push_str("<files>\n");
229        for context in file_context {
230            result.push_str(&context.context_buffer.text);
231        }
232        result.push_str("</files>\n");
233    }
234
235    if !directory_context.is_empty() {
236        result.push_str("<directories>\n");
237        for context in directory_context {
238            for context_buffer in &context.context_buffers {
239                result.push_str(&context_buffer.text);
240            }
241        }
242        result.push_str("</directories>\n");
243    }
244
245    if !symbol_context.is_empty() {
246        result.push_str("<symbols>\n");
247        for context in symbol_context {
248            result.push_str(&context.context_symbol.text);
249            result.push('\n');
250        }
251        result.push_str("</symbols>\n");
252    }
253
254    if !excerpt_context.is_empty() {
255        result.push_str("<excerpts>\n");
256        for context in excerpt_context {
257            result.push_str(&context.context_buffer.text);
258            result.push('\n');
259        }
260        result.push_str("</excerpts>\n");
261    }
262
263    if !fetch_context.is_empty() {
264        result.push_str("<fetched_urls>\n");
265        for context in &fetch_context {
266            result.push_str(&context.url);
267            result.push('\n');
268            result.push_str(&context.text);
269            result.push('\n');
270        }
271        result.push_str("</fetched_urls>\n");
272    }
273
274    if !thread_context.is_empty() {
275        result.push_str("<conversation_threads>\n");
276        for context in &thread_context {
277            result.push_str(&context.summary(cx));
278            result.push('\n');
279            result.push_str(&context.text);
280            result.push('\n');
281        }
282        result.push_str("</conversation_threads>\n");
283    }
284
285    if !rules_context.is_empty() {
286        result.push_str(
287            "<user_rules>\n\
288            The user has specified the following rules that should be applied:\n\n",
289        );
290        for context in &rules_context {
291            result.push_str(&context.text);
292            result.push('\n');
293        }
294        result.push_str("</user_rules>\n");
295    }
296
297    result.push_str("</context>\n");
298    Some(result)
299}
300
301pub fn attach_context_to_message<'a>(
302    message: &mut LanguageModelRequestMessage,
303    contexts: impl Iterator<Item = &'a AssistantContext>,
304    cx: &App,
305) {
306    if let Some(context_string) = format_context_as_string(contexts, cx) {
307        message.content.push(context_string.into());
308    }
309}