context.rs

  1use std::{ops::Range, path::Path, sync::Arc};
  2
  3use gpui::{App, Entity, SharedString};
  4use language::{Buffer, File};
  5use language_model::LanguageModelRequestMessage;
  6use project::{ProjectPath, Worktree};
  7use serde::{Deserialize, Serialize};
  8use text::{Anchor, BufferId};
  9use ui::IconName;
 10use util::post_inc;
 11
 12use crate::thread::Thread;
 13
/// Identifier carried by every context item in this file (see the `id` field
/// of [`FileContext`], [`DirectoryContext`], [`SymbolContext`],
/// [`FetchedUrlContext`] and [`ThreadContext`]).
///
/// A newtype over `usize`; the inner value is only accessible within the crate.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct ContextId(pub(crate) usize);
 16
 17impl ContextId {
 18    pub fn post_inc(&mut self) -> Self {
 19        Self(post_inc(&mut self.0))
 20    }
 21}
 22pub enum ContextKind {
 23    File,
 24    Directory,
 25    Symbol,
 26    FetchedUrl,
 27    Thread,
 28}
 29
 30impl ContextKind {
 31    pub fn icon(&self) -> IconName {
 32        match self {
 33            ContextKind::File => IconName::File,
 34            ContextKind::Directory => IconName::Folder,
 35            ContextKind::Symbol => IconName::Code,
 36            ContextKind::FetchedUrl => IconName::Globe,
 37            ContextKind::Thread => IconName::MessageBubbles,
 38        }
 39    }
 40}
 41
/// A context item together with its payload.
///
/// Each variant wraps the struct that holds the data for that kind of context.
#[derive(Debug, Clone)]
pub enum AssistantContext {
    /// A single file; see [`FileContext`].
    File(FileContext),
    /// A directory and the buffers of its files; see [`DirectoryContext`].
    Directory(DirectoryContext),
    /// A symbol within a buffer; see [`SymbolContext`].
    Symbol(SymbolContext),
    /// A URL and its associated text; see [`FetchedUrlContext`].
    FetchedUrl(FetchedUrlContext),
    /// A conversation thread; see [`ThreadContext`].
    Thread(ThreadContext),
}
 50
 51impl AssistantContext {
 52    pub fn id(&self) -> ContextId {
 53        match self {
 54            Self::File(file) => file.id,
 55            Self::Directory(directory) => directory.id,
 56            Self::Symbol(symbol) => symbol.id,
 57            Self::FetchedUrl(url) => url.id,
 58            Self::Thread(thread) => thread.id,
 59        }
 60    }
 61}
 62
/// Context for a single file.
#[derive(Debug, Clone)]
pub struct FileContext {
    /// Identifier of this context item.
    pub id: ContextId,
    /// The buffer backing this file, along with the text that
    /// `format_context_as_string` emits for it.
    pub context_buffer: ContextBuffer,
}
 68
/// Context for a directory and the files inside it.
#[derive(Debug, Clone)]
pub struct DirectoryContext {
    /// Identifier of this context item.
    pub id: ContextId,
    /// The worktree whose id is used when building the directory's `ProjectPath`.
    pub worktree: Entity<Worktree>,
    /// Path of the directory; paired with the worktree id in `project_path`.
    /// NOTE(review): presumably relative to the worktree root — confirm.
    pub path: Arc<Path>,
    /// Buffers of the files within the directory.
    pub context_buffers: Vec<ContextBuffer>,
}
 77
 78impl DirectoryContext {
 79    pub fn project_path(&self, cx: &App) -> ProjectPath {
 80        ProjectPath {
 81            worktree_id: self.worktree.read(cx).id(),
 82            path: self.path.clone(),
 83        }
 84    }
 85}
 86
/// Context for a single symbol.
#[derive(Debug, Clone)]
pub struct SymbolContext {
    /// Identifier of this context item.
    pub id: ContextId,
    /// The symbol itself: its identity, source buffer, range and text.
    pub context_symbol: ContextSymbol,
}
 92
/// Context for a URL and the text associated with it.
#[derive(Debug, Clone)]
pub struct FetchedUrlContext {
    /// Identifier of this context item.
    pub id: ContextId,
    /// The URL; `format_context_as_string` emits it on its own line.
    pub url: SharedString,
    /// Text emitted right after the URL by `format_context_as_string`.
    /// NOTE(review): presumably the content fetched from `url` — confirm.
    pub text: SharedString,
}
 99
/// Context for a conversation thread.
#[derive(Debug, Clone)]
pub struct ThreadContext {
    /// Identifier of this context item.
    pub id: ContextId,
    // TODO: Entity<Thread> holds onto the thread even if the thread is deleted. Should probably be
    // a WeakEntity and handle removal from the UI when it has dropped.
    /// The thread; its summary is read in `summary`.
    pub thread: Entity<Thread>,
    /// Text emitted right after the thread summary by `format_context_as_string`.
    pub text: SharedString,
}
108
109impl ThreadContext {
110    pub fn summary(&self, cx: &App) -> SharedString {
111        self.thread
112            .read(cx)
113            .summary()
114            .unwrap_or("New thread".into())
115    }
116}
117
/// A buffer attached as context, together with the text emitted for it.
///
/// `Debug` is implemented by hand for this type rather than derived.
#[derive(Clone)]
pub struct ContextBuffer {
    /// Id of the buffer.
    pub id: BufferId,
    // TODO: Entity<Buffer> holds onto the buffer even if the buffer is deleted. Should probably be
    // a WeakEntity and handle removal from the UI when it has dropped.
    /// Handle to the buffer.
    pub buffer: Entity<Buffer>,
    /// The file associated with the buffer.
    pub file: Arc<dyn File>,
    /// NOTE(review): presumably the buffer version at the time `text` was
    /// captured — confirm against the code that constructs this struct.
    pub version: clock::Global,
    /// Text appended verbatim by `format_context_as_string` for this buffer.
    pub text: SharedString,
}
128
129impl std::fmt::Debug for ContextBuffer {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        f.debug_struct("ContextBuffer")
132            .field("id", &self.id)
133            .field("buffer", &self.buffer)
134            .field("version", &self.version)
135            .field("text", &self.text)
136            .finish()
137    }
138}
139
/// A symbol attached as context, together with the text emitted for it.
#[derive(Debug, Clone)]
pub struct ContextSymbol {
    /// Identity of the symbol (path, name and range).
    pub id: ContextSymbolId,
    /// Handle to the buffer the symbol lives in.
    pub buffer: Entity<Buffer>,
    /// NOTE(review): presumably the buffer version at the time `text` was
    /// captured — confirm against the code that constructs this struct.
    pub buffer_version: clock::Global,
    /// The range that the symbol encloses, e.g. for a function symbol, this will
    /// include not only the signature, but also the body.
    pub enclosing_range: Range<Anchor>,
    /// Text emitted for this symbol by `format_context_as_string`.
    pub text: SharedString,
}
150
/// Identity of a symbol: where it lives, what it is called, and its range.
///
/// Implements `PartialEq`, `Eq` and `Hash`, so it can serve as a lookup key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ContextSymbolId {
    /// Project path of the file containing the symbol.
    pub path: ProjectPath,
    /// Name of the symbol.
    pub name: SharedString,
    /// Anchor range of the symbol.
    /// NOTE(review): how this relates to `ContextSymbol::enclosing_range` is
    /// not visible here — confirm.
    pub range: Range<Anchor>,
}
157
158/// Formats a collection of contexts into a string representation
159pub fn format_context_as_string<'a>(
160    contexts: impl Iterator<Item = &'a AssistantContext>,
161    cx: &App,
162) -> Option<String> {
163    let mut file_context = Vec::new();
164    let mut directory_context = Vec::new();
165    let mut symbol_context = Vec::new();
166    let mut fetch_context = Vec::new();
167    let mut thread_context = Vec::new();
168
169    for context in contexts {
170        match context {
171            AssistantContext::File(context) => file_context.push(context),
172            AssistantContext::Directory(context) => directory_context.push(context),
173            AssistantContext::Symbol(context) => symbol_context.push(context),
174            AssistantContext::FetchedUrl(context) => fetch_context.push(context),
175            AssistantContext::Thread(context) => thread_context.push(context),
176        }
177    }
178
179    if file_context.is_empty()
180        && directory_context.is_empty()
181        && symbol_context.is_empty()
182        && fetch_context.is_empty()
183        && thread_context.is_empty()
184    {
185        return None;
186    }
187
188    let mut result = String::new();
189    result.push_str("\n<context>\n\
190        The following items were attached by the user. You don't need to use other tools to read them.\n\n");
191
192    if !file_context.is_empty() {
193        result.push_str("<files>\n");
194        for context in file_context {
195            result.push_str(&context.context_buffer.text);
196        }
197        result.push_str("</files>\n");
198    }
199
200    if !directory_context.is_empty() {
201        result.push_str("<directories>\n");
202        for context in directory_context {
203            for context_buffer in &context.context_buffers {
204                result.push_str(&context_buffer.text);
205            }
206        }
207        result.push_str("</directories>\n");
208    }
209
210    if !symbol_context.is_empty() {
211        result.push_str("<symbols>\n");
212        for context in symbol_context {
213            result.push_str(&context.context_symbol.text);
214            result.push('\n');
215        }
216        result.push_str("</symbols>\n");
217    }
218
219    if !fetch_context.is_empty() {
220        result.push_str("<fetched_urls>\n");
221        for context in &fetch_context {
222            result.push_str(&context.url);
223            result.push('\n');
224            result.push_str(&context.text);
225            result.push('\n');
226        }
227        result.push_str("</fetched_urls>\n");
228    }
229
230    if !thread_context.is_empty() {
231        result.push_str("<conversation_threads>\n");
232        for context in &thread_context {
233            result.push_str(&context.summary(cx));
234            result.push('\n');
235            result.push_str(&context.text);
236            result.push('\n');
237        }
238        result.push_str("</conversation_threads>\n");
239    }
240
241    result.push_str("</context>\n");
242    Some(result)
243}
244
245pub fn attach_context_to_message<'a>(
246    message: &mut LanguageModelRequestMessage,
247    contexts: impl Iterator<Item = &'a AssistantContext>,
248    cx: &App,
249) {
250    if let Some(context_string) = format_context_as_string(contexts, cx) {
251        message.content.push(context_string.into());
252    }
253}