context.rs

  1use std::{ops::Range, sync::Arc};
  2
  3use gpui::{App, Entity, SharedString};
  4use language::{Buffer, File};
  5use language_model::{LanguageModelRequestMessage, MessageContent};
  6use project::ProjectPath;
  7use serde::{Deserialize, Serialize};
  8use text::{Anchor, BufferId};
  9use ui::IconName;
 10use util::post_inc;
 11
 12use crate::thread::Thread;
 13
 14#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 15pub struct ContextId(pub(crate) usize);
 16
 17impl ContextId {
 18    pub fn post_inc(&mut self) -> Self {
 19        Self(post_inc(&mut self.0))
 20    }
 21}
 22
 23#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 24pub enum ContextKind {
 25    File,
 26    Directory,
 27    Symbol,
 28    FetchedUrl,
 29    Thread,
 30}
 31
 32impl ContextKind {
 33    pub fn icon(&self) -> IconName {
 34        match self {
 35            ContextKind::File => IconName::File,
 36            ContextKind::Directory => IconName::Folder,
 37            ContextKind::Symbol => IconName::Code,
 38            ContextKind::FetchedUrl => IconName::Globe,
 39            ContextKind::Thread => IconName::MessageBubbles,
 40        }
 41    }
 42}
 43
 44#[derive(Debug, Clone)]
 45pub enum AssistantContext {
 46    File(FileContext),
 47    Directory(DirectoryContext),
 48    Symbol(SymbolContext),
 49    FetchedUrl(FetchedUrlContext),
 50    Thread(ThreadContext),
 51}
 52
 53impl AssistantContext {
 54    pub fn id(&self) -> ContextId {
 55        match self {
 56            Self::File(file) => file.id,
 57            Self::Directory(directory) => directory.id,
 58            Self::Symbol(symbol) => symbol.id,
 59            Self::FetchedUrl(url) => url.id,
 60            Self::Thread(thread) => thread.id,
 61        }
 62    }
 63}
 64
 65#[derive(Debug, Clone)]
 66pub struct FileContext {
 67    pub id: ContextId,
 68    pub context_buffer: ContextBuffer,
 69}
 70
 71#[derive(Debug, Clone)]
 72pub struct DirectoryContext {
 73    pub id: ContextId,
 74    pub project_path: ProjectPath,
 75    pub context_buffers: Vec<ContextBuffer>,
 76}
 77
 78#[derive(Debug, Clone)]
 79pub struct SymbolContext {
 80    pub id: ContextId,
 81    pub context_symbol: ContextSymbol,
 82}
 83
 84#[derive(Debug, Clone)]
 85pub struct FetchedUrlContext {
 86    pub id: ContextId,
 87    pub url: SharedString,
 88    pub text: SharedString,
 89}
 90
// TODO: `Entity<Thread>` keeps the thread alive even after the thread is deleted. Either handle
// this explicitly, or store a `WeakEntity<Thread>` and remove the context during snapshot.
 93
 94#[derive(Debug, Clone)]
 95pub struct ThreadContext {
 96    pub id: ContextId,
 97    pub thread: Entity<Thread>,
 98    pub text: SharedString,
 99}
100
101impl ThreadContext {
102    pub fn summary(&self, cx: &App) -> SharedString {
103        self.thread
104            .read(cx)
105            .summary()
106            .unwrap_or("New thread".into())
107    }
108}
109
// TODO: `Entity<Buffer>` keeps the buffer alive even after the file is deleted and closed. In
// that case the context should be removed from the message editor.
112
113#[derive(Clone)]
114pub struct ContextBuffer {
115    pub id: BufferId,
116    pub buffer: Entity<Buffer>,
117    pub file: Arc<dyn File>,
118    pub version: clock::Global,
119    pub text: SharedString,
120}
121
122impl std::fmt::Debug for ContextBuffer {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        f.debug_struct("ContextBuffer")
125            .field("id", &self.id)
126            .field("buffer", &self.buffer)
127            .field("version", &self.version)
128            .field("text", &self.text)
129            .finish()
130    }
131}
132
133#[derive(Debug, Clone)]
134pub struct ContextSymbol {
135    pub id: ContextSymbolId,
136    pub buffer: Entity<Buffer>,
137    pub buffer_version: clock::Global,
138    /// The range that the symbol encloses, e.g. for function symbol, this will
139    /// include not only the signature, but also the body
140    pub enclosing_range: Range<Anchor>,
141    pub text: SharedString,
142}
143
144#[derive(Debug, Clone, PartialEq, Eq, Hash)]
145pub struct ContextSymbolId {
146    pub path: ProjectPath,
147    pub name: SharedString,
148    pub range: Range<Anchor>,
149}
150
151pub fn attach_context_to_message<'a>(
152    message: &mut LanguageModelRequestMessage,
153    contexts: impl Iterator<Item = &'a AssistantContext>,
154    cx: &App,
155) {
156    let mut file_context = Vec::new();
157    let mut directory_context = Vec::new();
158    let mut symbol_context = Vec::new();
159    let mut fetch_context = Vec::new();
160    let mut thread_context = Vec::new();
161
162    for context in contexts {
163        match context {
164            AssistantContext::File(context) => file_context.push(context),
165            AssistantContext::Directory(context) => directory_context.push(context),
166            AssistantContext::Symbol(context) => symbol_context.push(context),
167            AssistantContext::FetchedUrl(context) => fetch_context.push(context),
168            AssistantContext::Thread(context) => thread_context.push(context),
169        }
170    }
171
172    let mut context_chunks = Vec::new();
173
174    if !file_context.is_empty() {
175        context_chunks.push("The following files are available:\n");
176        for context in file_context {
177            context_chunks.push(&context.context_buffer.text);
178        }
179    }
180
181    if !directory_context.is_empty() {
182        context_chunks.push("The following directories are available:\n");
183        for context in directory_context {
184            for context_buffer in &context.context_buffers {
185                context_chunks.push(&context_buffer.text);
186            }
187        }
188    }
189
190    if !symbol_context.is_empty() {
191        context_chunks.push("The following symbols are available:\n");
192        for context in symbol_context {
193            context_chunks.push(&context.context_symbol.text);
194        }
195    }
196
197    if !fetch_context.is_empty() {
198        context_chunks.push("The following fetched results are available:\n");
199        for context in &fetch_context {
200            context_chunks.push(&context.url);
201            context_chunks.push(&context.text);
202        }
203    }
204
205    // Need to own the SharedString for summary so that it can be referenced.
206    let mut thread_context_chunks = Vec::new();
207    if !thread_context.is_empty() {
208        context_chunks.push("The following previous conversation threads are available:\n");
209        for context in &thread_context {
210            thread_context_chunks.push(context.summary(cx));
211            thread_context_chunks.push(context.text.clone());
212        }
213    }
214    for chunk in &thread_context_chunks {
215        context_chunks.push(chunk);
216    }
217
218    if !context_chunks.is_empty() {
219        message
220            .content
221            .push(MessageContent::Text(context_chunks.join("\n")));
222    }
223}