context.rs

  1use std::{ops::Range, sync::Arc};
  2
  3use gpui::{App, Entity, SharedString};
  4use language::{Buffer, File};
  5use language_model::LanguageModelRequestMessage;
  6use project::ProjectPath;
  7use serde::{Deserialize, Serialize};
  8use text::{Anchor, BufferId};
  9use ui::IconName;
 10use util::post_inc;
 11
 12use crate::thread::Thread;
 13
 14#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 15pub struct ContextId(pub(crate) usize);
 16
 17impl ContextId {
 18    pub fn post_inc(&mut self) -> Self {
 19        Self(post_inc(&mut self.0))
 20    }
 21}
 22pub enum ContextKind {
 23    File,
 24    Directory,
 25    Symbol,
 26    FetchedUrl,
 27    Thread,
 28}
 29
 30impl ContextKind {
 31    pub fn icon(&self) -> IconName {
 32        match self {
 33            ContextKind::File => IconName::File,
 34            ContextKind::Directory => IconName::Folder,
 35            ContextKind::Symbol => IconName::Code,
 36            ContextKind::FetchedUrl => IconName::Globe,
 37            ContextKind::Thread => IconName::MessageBubbles,
 38        }
 39    }
 40}
 41
 42#[derive(Debug, Clone)]
 43pub enum AssistantContext {
 44    File(FileContext),
 45    Directory(DirectoryContext),
 46    Symbol(SymbolContext),
 47    FetchedUrl(FetchedUrlContext),
 48    Thread(ThreadContext),
 49}
 50
 51impl AssistantContext {
 52    pub fn id(&self) -> ContextId {
 53        match self {
 54            Self::File(file) => file.id,
 55            Self::Directory(directory) => directory.id,
 56            Self::Symbol(symbol) => symbol.id,
 57            Self::FetchedUrl(url) => url.id,
 58            Self::Thread(thread) => thread.id,
 59        }
 60    }
 61}
 62
 63#[derive(Debug, Clone)]
 64pub struct FileContext {
 65    pub id: ContextId,
 66    pub context_buffer: ContextBuffer,
 67}
 68
 69#[derive(Debug, Clone)]
 70pub struct DirectoryContext {
 71    pub id: ContextId,
 72    pub project_path: ProjectPath,
 73    pub context_buffers: Vec<ContextBuffer>,
 74}
 75
 76#[derive(Debug, Clone)]
 77pub struct SymbolContext {
 78    pub id: ContextId,
 79    pub context_symbol: ContextSymbol,
 80}
 81
 82#[derive(Debug, Clone)]
 83pub struct FetchedUrlContext {
 84    pub id: ContextId,
 85    pub url: SharedString,
 86    pub text: SharedString,
 87}
 88
// TODO: Entity<Thread> holds onto the thread even if the thread is deleted. Either handle this
// explicitly or hold a WeakEntity<Thread> and remove the context during snapshot.
 91
 92#[derive(Debug, Clone)]
 93pub struct ThreadContext {
 94    pub id: ContextId,
 95    pub thread: Entity<Thread>,
 96    pub text: SharedString,
 97}
 98
 99impl ThreadContext {
100    pub fn summary(&self, cx: &App) -> SharedString {
101        self.thread
102            .read(cx)
103            .summary()
104            .unwrap_or("New thread".into())
105    }
106}
107
// TODO: Entity<Buffer> holds onto the buffer even if the file is deleted and closed. The context
// should be removed from the message editor in that case.
110
111#[derive(Clone)]
112pub struct ContextBuffer {
113    pub id: BufferId,
114    pub buffer: Entity<Buffer>,
115    pub file: Arc<dyn File>,
116    pub version: clock::Global,
117    pub text: SharedString,
118}
119
120impl std::fmt::Debug for ContextBuffer {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        f.debug_struct("ContextBuffer")
123            .field("id", &self.id)
124            .field("buffer", &self.buffer)
125            .field("version", &self.version)
126            .field("text", &self.text)
127            .finish()
128    }
129}
130
131#[derive(Debug, Clone)]
132pub struct ContextSymbol {
133    pub id: ContextSymbolId,
134    pub buffer: Entity<Buffer>,
135    pub buffer_version: clock::Global,
136    /// The range that the symbol encloses, e.g. for function symbol, this will
137    /// include not only the signature, but also the body
138    pub enclosing_range: Range<Anchor>,
139    pub text: SharedString,
140}
141
142#[derive(Debug, Clone, PartialEq, Eq, Hash)]
143pub struct ContextSymbolId {
144    pub path: ProjectPath,
145    pub name: SharedString,
146    pub range: Range<Anchor>,
147}
148
149/// Formats a collection of contexts into a string representation
150pub fn format_context_as_string<'a>(
151    contexts: impl Iterator<Item = &'a AssistantContext>,
152    cx: &App,
153) -> Option<String> {
154    let mut file_context = Vec::new();
155    let mut directory_context = Vec::new();
156    let mut symbol_context = Vec::new();
157    let mut fetch_context = Vec::new();
158    let mut thread_context = Vec::new();
159
160    for context in contexts {
161        match context {
162            AssistantContext::File(context) => file_context.push(context),
163            AssistantContext::Directory(context) => directory_context.push(context),
164            AssistantContext::Symbol(context) => symbol_context.push(context),
165            AssistantContext::FetchedUrl(context) => fetch_context.push(context),
166            AssistantContext::Thread(context) => thread_context.push(context),
167        }
168    }
169
170    if file_context.is_empty()
171        && directory_context.is_empty()
172        && symbol_context.is_empty()
173        && fetch_context.is_empty()
174        && thread_context.is_empty()
175    {
176        return None;
177    }
178
179    let mut result = String::new();
180    result.push_str("\n<context>\n\
181        The following items were attached by the user. You don't need to use other tools to read them.\n\n");
182
183    if !file_context.is_empty() {
184        result.push_str("<files>\n");
185        for context in file_context {
186            result.push_str(&context.context_buffer.text);
187        }
188        result.push_str("</files>\n");
189    }
190
191    if !directory_context.is_empty() {
192        result.push_str("<directories>\n");
193        for context in directory_context {
194            for context_buffer in &context.context_buffers {
195                result.push_str(&context_buffer.text);
196            }
197        }
198        result.push_str("</directories>\n");
199    }
200
201    if !symbol_context.is_empty() {
202        result.push_str("<symbols>\n");
203        for context in symbol_context {
204            result.push_str(&context.context_symbol.text);
205            result.push('\n');
206        }
207        result.push_str("</symbols>\n");
208    }
209
210    if !fetch_context.is_empty() {
211        result.push_str("<fetched_urls>\n");
212        for context in &fetch_context {
213            result.push_str(&context.url);
214            result.push('\n');
215            result.push_str(&context.text);
216            result.push('\n');
217        }
218        result.push_str("</fetched_urls>\n");
219    }
220
221    if !thread_context.is_empty() {
222        result.push_str("<conversation_threads>\n");
223        for context in &thread_context {
224            result.push_str(&context.summary(cx));
225            result.push('\n');
226            result.push_str(&context.text);
227            result.push('\n');
228        }
229        result.push_str("</conversation_threads>\n");
230    }
231
232    result.push_str("</context>\n");
233    Some(result)
234}
235
236pub fn attach_context_to_message<'a>(
237    message: &mut LanguageModelRequestMessage,
238    contexts: impl Iterator<Item = &'a AssistantContext>,
239    cx: &App,
240) {
241    if let Some(context_string) = format_context_as_string(contexts, cx) {
242        message.content.push(context_string.into());
243    }
244}