context.rs

  1use std::{ops::Range, sync::Arc};
  2
  3use gpui::{App, Entity, SharedString};
  4use language::{Buffer, File};
  5use language_model::{LanguageModelRequestMessage, MessageContent};
  6use project::ProjectPath;
  7use serde::{Deserialize, Serialize};
  8use text::{Anchor, BufferId};
  9use ui::IconName;
 10use util::post_inc;
 11
 12use crate::thread::Thread;
 13
 14#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 15pub struct ContextId(pub(crate) usize);
 16
 17impl ContextId {
 18    pub fn post_inc(&mut self) -> Self {
 19        Self(post_inc(&mut self.0))
 20    }
 21}
 22pub enum ContextKind {
 23    File,
 24    Directory,
 25    Symbol,
 26    FetchedUrl,
 27    Thread,
 28}
 29
 30impl ContextKind {
 31    pub fn icon(&self) -> IconName {
 32        match self {
 33            ContextKind::File => IconName::File,
 34            ContextKind::Directory => IconName::Folder,
 35            ContextKind::Symbol => IconName::Code,
 36            ContextKind::FetchedUrl => IconName::Globe,
 37            ContextKind::Thread => IconName::MessageBubbles,
 38        }
 39    }
 40}
 41
 42#[derive(Debug, Clone)]
 43pub enum AssistantContext {
 44    File(FileContext),
 45    Directory(DirectoryContext),
 46    Symbol(SymbolContext),
 47    FetchedUrl(FetchedUrlContext),
 48    Thread(ThreadContext),
 49}
 50
 51impl AssistantContext {
 52    pub fn id(&self) -> ContextId {
 53        match self {
 54            Self::File(file) => file.id,
 55            Self::Directory(directory) => directory.id,
 56            Self::Symbol(symbol) => symbol.id,
 57            Self::FetchedUrl(url) => url.id,
 58            Self::Thread(thread) => thread.id,
 59        }
 60    }
 61}
 62
 63#[derive(Debug, Clone)]
 64pub struct FileContext {
 65    pub id: ContextId,
 66    pub context_buffer: ContextBuffer,
 67}
 68
 69#[derive(Debug, Clone)]
 70pub struct DirectoryContext {
 71    pub id: ContextId,
 72    pub project_path: ProjectPath,
 73    pub context_buffers: Vec<ContextBuffer>,
 74}
 75
 76#[derive(Debug, Clone)]
 77pub struct SymbolContext {
 78    pub id: ContextId,
 79    pub context_symbol: ContextSymbol,
 80}
 81
 82#[derive(Debug, Clone)]
 83pub struct FetchedUrlContext {
 84    pub id: ContextId,
 85    pub url: SharedString,
 86    pub text: SharedString,
 87}
 88
// TODO: `Entity<Thread>` keeps the thread alive even after the thread is deleted. Either handle
// this explicitly, or store a `WeakEntity<Thread>` and drop the context during snapshot.
 91
 92#[derive(Debug, Clone)]
 93pub struct ThreadContext {
 94    pub id: ContextId,
 95    pub thread: Entity<Thread>,
 96    pub text: SharedString,
 97}
 98
 99impl ThreadContext {
100    pub fn summary(&self, cx: &App) -> SharedString {
101        self.thread
102            .read(cx)
103            .summary()
104            .unwrap_or("New thread".into())
105    }
106}
107
// TODO: `Entity<Buffer>` keeps the buffer alive even after its file is deleted and closed. In that
// case the context should be removed from the message editor.
110
111#[derive(Clone)]
112pub struct ContextBuffer {
113    pub id: BufferId,
114    pub buffer: Entity<Buffer>,
115    pub file: Arc<dyn File>,
116    pub version: clock::Global,
117    pub text: SharedString,
118}
119
120impl std::fmt::Debug for ContextBuffer {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        f.debug_struct("ContextBuffer")
123            .field("id", &self.id)
124            .field("buffer", &self.buffer)
125            .field("version", &self.version)
126            .field("text", &self.text)
127            .finish()
128    }
129}
130
131#[derive(Debug, Clone)]
132pub struct ContextSymbol {
133    pub id: ContextSymbolId,
134    pub buffer: Entity<Buffer>,
135    pub buffer_version: clock::Global,
136    /// The range that the symbol encloses, e.g. for function symbol, this will
137    /// include not only the signature, but also the body
138    pub enclosing_range: Range<Anchor>,
139    pub text: SharedString,
140}
141
142#[derive(Debug, Clone, PartialEq, Eq, Hash)]
143pub struct ContextSymbolId {
144    pub path: ProjectPath,
145    pub name: SharedString,
146    pub range: Range<Anchor>,
147}
148
149pub fn attach_context_to_message<'a>(
150    message: &mut LanguageModelRequestMessage,
151    contexts: impl Iterator<Item = &'a AssistantContext>,
152    cx: &App,
153) {
154    let mut file_context = Vec::new();
155    let mut directory_context = Vec::new();
156    let mut symbol_context = Vec::new();
157    let mut fetch_context = Vec::new();
158    let mut thread_context = Vec::new();
159
160    for context in contexts {
161        match context {
162            AssistantContext::File(context) => file_context.push(context),
163            AssistantContext::Directory(context) => directory_context.push(context),
164            AssistantContext::Symbol(context) => symbol_context.push(context),
165            AssistantContext::FetchedUrl(context) => fetch_context.push(context),
166            AssistantContext::Thread(context) => thread_context.push(context),
167        }
168    }
169
170    let mut context_chunks = Vec::new();
171
172    if !file_context.is_empty() {
173        context_chunks.push("The following files are available:\n");
174        for context in file_context {
175            context_chunks.push(&context.context_buffer.text);
176        }
177    }
178
179    if !directory_context.is_empty() {
180        context_chunks.push("The following directories are available:\n");
181        for context in directory_context {
182            for context_buffer in &context.context_buffers {
183                context_chunks.push(&context_buffer.text);
184            }
185        }
186    }
187
188    if !symbol_context.is_empty() {
189        context_chunks.push("The following symbols are available:\n");
190        for context in symbol_context {
191            context_chunks.push(&context.context_symbol.text);
192        }
193    }
194
195    if !fetch_context.is_empty() {
196        context_chunks.push("The following fetched results are available:\n");
197        for context in &fetch_context {
198            context_chunks.push(&context.url);
199            context_chunks.push(&context.text);
200        }
201    }
202
203    // Need to own the SharedString for summary so that it can be referenced.
204    let mut thread_context_chunks = Vec::new();
205    if !thread_context.is_empty() {
206        context_chunks.push("The following previous conversation threads are available:\n");
207        for context in &thread_context {
208            thread_context_chunks.push(context.summary(cx));
209            thread_context_chunks.push(context.text.clone());
210        }
211    }
212    for chunk in &thread_context_chunks {
213        context_chunks.push(chunk);
214    }
215
216    if !context_chunks.is_empty() {
217        message
218            .content
219            .push(MessageContent::Text(context_chunks.join("\n")));
220    }
221}