context.rs

  1use std::{ops::Range, path::Path, sync::Arc};
  2
  3use gpui::{App, Entity, SharedString};
  4use language::{Buffer, File};
  5use language_model::LanguageModelRequestMessage;
  6use project::{ProjectPath, Worktree};
  7use rope::Point;
  8use serde::{Deserialize, Serialize};
  9use text::{Anchor, BufferId};
 10use ui::IconName;
 11use util::post_inc;
 12
 13use crate::thread::Thread;
 14
 15#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 16pub struct ContextId(pub(crate) usize);
 17
 18impl ContextId {
 19    pub fn post_inc(&mut self) -> Self {
 20        Self(post_inc(&mut self.0))
 21    }
 22}
 23pub enum ContextKind {
 24    File,
 25    Directory,
 26    Symbol,
 27    Excerpt,
 28    FetchedUrl,
 29    Thread,
 30}
 31
 32impl ContextKind {
 33    pub fn icon(&self) -> IconName {
 34        match self {
 35            ContextKind::File => IconName::File,
 36            ContextKind::Directory => IconName::Folder,
 37            ContextKind::Symbol => IconName::Code,
 38            ContextKind::Excerpt => IconName::Code,
 39            ContextKind::FetchedUrl => IconName::Globe,
 40            ContextKind::Thread => IconName::MessageBubbles,
 41        }
 42    }
 43}
 44
 45#[derive(Debug, Clone)]
 46pub enum AssistantContext {
 47    File(FileContext),
 48    Directory(DirectoryContext),
 49    Symbol(SymbolContext),
 50    FetchedUrl(FetchedUrlContext),
 51    Thread(ThreadContext),
 52    Excerpt(ExcerptContext),
 53}
 54
 55impl AssistantContext {
 56    pub fn id(&self) -> ContextId {
 57        match self {
 58            Self::File(file) => file.id,
 59            Self::Directory(directory) => directory.id,
 60            Self::Symbol(symbol) => symbol.id,
 61            Self::FetchedUrl(url) => url.id,
 62            Self::Thread(thread) => thread.id,
 63            Self::Excerpt(excerpt) => excerpt.id,
 64        }
 65    }
 66}
 67
/// Context entry for a single attached file.
#[derive(Debug, Clone)]
pub struct FileContext {
    /// Identifier reported by `AssistantContext::id` for this entry.
    pub id: ContextId,
    /// Buffer whose `text` is emitted verbatim inside the `<files>` section
    /// produced by `format_context_as_string`.
    pub context_buffer: ContextBuffer,
}
 73
 74#[derive(Debug, Clone)]
 75pub struct DirectoryContext {
 76    pub id: ContextId,
 77    pub worktree: Entity<Worktree>,
 78    pub path: Arc<Path>,
 79    /// Buffers of the files within the directory.
 80    pub context_buffers: Vec<ContextBuffer>,
 81}
 82
 83impl DirectoryContext {
 84    pub fn project_path(&self, cx: &App) -> ProjectPath {
 85        ProjectPath {
 86            worktree_id: self.worktree.read(cx).id(),
 87            path: self.path.clone(),
 88        }
 89    }
 90}
 91
/// Context entry for a single attached symbol.
#[derive(Debug, Clone)]
pub struct SymbolContext {
    /// Identifier reported by `AssistantContext::id` for this entry.
    pub id: ContextId,
    /// The symbol; its `text` is emitted, followed by a newline, inside the
    /// `<symbols>` section produced by `format_context_as_string`.
    pub context_symbol: ContextSymbol,
}
 97
/// Context entry for content fetched from a URL.
#[derive(Debug, Clone)]
pub struct FetchedUrlContext {
    /// Identifier reported by `AssistantContext::id` for this entry.
    pub id: ContextId,
    /// The URL; written on its own line inside `<fetched_urls>`, before `text`.
    pub url: SharedString,
    /// Written on its own line right after `url`. NOTE(review): presumably
    /// the fetched content (possibly converted) — confirm against the fetcher.
    pub text: SharedString,
}
104
105#[derive(Debug, Clone)]
106pub struct ThreadContext {
107    pub id: ContextId,
108    // TODO: Entity<Thread> holds onto the thread even if the thread is deleted. Should probably be
109    // a WeakEntity and handle removal from the UI when it has dropped.
110    pub thread: Entity<Thread>,
111    pub text: SharedString,
112}
113
114impl ThreadContext {
115    pub fn summary(&self, cx: &App) -> SharedString {
116        self.thread
117            .read(cx)
118            .summary()
119            .unwrap_or("New thread".into())
120    }
121}
122
123#[derive(Clone)]
124pub struct ContextBuffer {
125    pub id: BufferId,
126    // TODO: Entity<Buffer> holds onto the thread even if the thread is deleted. Should probably be
127    // a WeakEntity and handle removal from the UI when it has dropped.
128    pub buffer: Entity<Buffer>,
129    pub file: Arc<dyn File>,
130    pub version: clock::Global,
131    pub text: SharedString,
132}
133
134impl std::fmt::Debug for ContextBuffer {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        f.debug_struct("ContextBuffer")
137            .field("id", &self.id)
138            .field("buffer", &self.buffer)
139            .field("version", &self.version)
140            .field("text", &self.text)
141            .finish()
142    }
143}
144
/// A symbol captured from a buffer for use as context.
#[derive(Debug, Clone)]
pub struct ContextSymbol {
    /// Identity of the symbol (path, name and range).
    pub id: ContextSymbolId,
    pub buffer: Entity<Buffer>,
    /// NOTE(review): presumably the buffer version at which `enclosing_range`
    /// and `text` were captured — confirm.
    pub buffer_version: clock::Global,
    /// The range that the symbol encloses; e.g. for a function symbol this
    /// includes not only the signature but also the body.
    pub enclosing_range: Range<Anchor>,
    /// Text emitted for this symbol inside the `<symbols>` section.
    pub text: SharedString,
}
155
/// Identifies a symbol by project path, name and range.
///
/// Derives `Eq` and `Hash`, so it can be used as a map or set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ContextSymbolId {
    pub path: ProjectPath,
    pub name: SharedString,
    /// NOTE(review): presumably the symbol's own range, as opposed to
    /// `ContextSymbol::enclosing_range` — confirm.
    pub range: Range<Anchor>,
}
162
/// Context entry for an excerpt taken from a buffer.
#[derive(Debug, Clone)]
pub struct ExcerptContext {
    /// Identifier reported by `AssistantContext::id` for this entry.
    pub id: ContextId,
    /// Anchor range of the excerpt.
    pub range: Range<Anchor>,
    /// NOTE(review): presumably the same span as `range`, expressed as
    /// row/column points — confirm.
    pub line_range: Range<Point>,
    /// Buffer whose `text` is emitted, followed by a newline, inside the
    /// `<excerpts>` section produced by `format_context_as_string`.
    pub context_buffer: ContextBuffer,
}
170
171/// Formats a collection of contexts into a string representation
172pub fn format_context_as_string<'a>(
173    contexts: impl Iterator<Item = &'a AssistantContext>,
174    cx: &App,
175) -> Option<String> {
176    let mut file_context = Vec::new();
177    let mut directory_context = Vec::new();
178    let mut symbol_context = Vec::new();
179    let mut excerpt_context = Vec::new();
180    let mut fetch_context = Vec::new();
181    let mut thread_context = Vec::new();
182
183    for context in contexts {
184        match context {
185            AssistantContext::File(context) => file_context.push(context),
186            AssistantContext::Directory(context) => directory_context.push(context),
187            AssistantContext::Symbol(context) => symbol_context.push(context),
188            AssistantContext::Excerpt(context) => excerpt_context.push(context),
189            AssistantContext::FetchedUrl(context) => fetch_context.push(context),
190            AssistantContext::Thread(context) => thread_context.push(context),
191        }
192    }
193
194    if file_context.is_empty()
195        && directory_context.is_empty()
196        && symbol_context.is_empty()
197        && excerpt_context.is_empty()
198        && fetch_context.is_empty()
199        && thread_context.is_empty()
200    {
201        return None;
202    }
203
204    let mut result = String::new();
205    result.push_str("\n<context>\n\
206        The following items were attached by the user. You don't need to use other tools to read them.\n\n");
207
208    if !file_context.is_empty() {
209        result.push_str("<files>\n");
210        for context in file_context {
211            result.push_str(&context.context_buffer.text);
212        }
213        result.push_str("</files>\n");
214    }
215
216    if !directory_context.is_empty() {
217        result.push_str("<directories>\n");
218        for context in directory_context {
219            for context_buffer in &context.context_buffers {
220                result.push_str(&context_buffer.text);
221            }
222        }
223        result.push_str("</directories>\n");
224    }
225
226    if !symbol_context.is_empty() {
227        result.push_str("<symbols>\n");
228        for context in symbol_context {
229            result.push_str(&context.context_symbol.text);
230            result.push('\n');
231        }
232        result.push_str("</symbols>\n");
233    }
234
235    if !excerpt_context.is_empty() {
236        result.push_str("<excerpts>\n");
237        for context in excerpt_context {
238            result.push_str(&context.context_buffer.text);
239            result.push('\n');
240        }
241        result.push_str("</excerpts>\n");
242    }
243
244    if !fetch_context.is_empty() {
245        result.push_str("<fetched_urls>\n");
246        for context in &fetch_context {
247            result.push_str(&context.url);
248            result.push('\n');
249            result.push_str(&context.text);
250            result.push('\n');
251        }
252        result.push_str("</fetched_urls>\n");
253    }
254
255    if !thread_context.is_empty() {
256        result.push_str("<conversation_threads>\n");
257        for context in &thread_context {
258            result.push_str(&context.summary(cx));
259            result.push('\n');
260            result.push_str(&context.text);
261            result.push('\n');
262        }
263        result.push_str("</conversation_threads>\n");
264    }
265
266    result.push_str("</context>\n");
267    Some(result)
268}
269
270pub fn attach_context_to_message<'a>(
271    message: &mut LanguageModelRequestMessage,
272    contexts: impl Iterator<Item = &'a AssistantContext>,
273    cx: &App,
274) {
275    if let Some(context_string) = format_context_as_string(contexts, cx) {
276        message.content.push(context_string.into());
277    }
278}