1use std::{ops::Range, sync::Arc};
2
3use gpui::{App, Entity, SharedString};
4use language::{Buffer, File};
5use language_model::{LanguageModelRequestMessage, MessageContent};
6use project::ProjectPath;
7use serde::{Deserialize, Serialize};
8use text::{Anchor, BufferId};
9use ui::IconName;
10use util::post_inc;
11
12use crate::thread::Thread;
13
14#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
15pub struct ContextId(pub(crate) usize);
16
17impl ContextId {
18 pub fn post_inc(&mut self) -> Self {
19 Self(post_inc(&mut self.0))
20 }
21}
22
/// The kinds of context that can be attached to a message.
///
/// There is one kind per `AssistantContext` variant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ContextKind {
    /// A single file.
    File,
    /// A directory.
    Directory,
    /// A symbol within a buffer.
    Symbol,
    /// Content fetched from a URL.
    FetchedUrl,
    /// A conversation thread.
    Thread,
}
31
32impl ContextKind {
33 pub fn icon(&self) -> IconName {
34 match self {
35 ContextKind::File => IconName::File,
36 ContextKind::Directory => IconName::Folder,
37 ContextKind::Symbol => IconName::Code,
38 ContextKind::FetchedUrl => IconName::Globe,
39 ContextKind::Thread => IconName::MessageBubbles,
40 }
41 }
42}
43
44#[derive(Debug, Clone)]
45pub enum AssistantContext {
46 File(FileContext),
47 Directory(DirectoryContext),
48 Symbol(SymbolContext),
49 FetchedUrl(FetchedUrlContext),
50 Thread(ThreadContext),
51}
52
53impl AssistantContext {
54 pub fn id(&self) -> ContextId {
55 match self {
56 Self::File(file) => file.id,
57 Self::Directory(directory) => directory.id,
58 Self::Symbol(symbol) => symbol.id,
59 Self::FetchedUrl(url) => url.id,
60 Self::Thread(thread) => thread.id,
61 }
62 }
63}
64
65#[derive(Debug, Clone)]
66pub struct FileContext {
67 pub id: ContextId,
68 pub context_buffer: ContextBuffer,
69}
70
71#[derive(Debug, Clone)]
72pub struct DirectoryContext {
73 pub id: ContextId,
74 pub project_path: ProjectPath,
75 pub context_buffers: Vec<ContextBuffer>,
76}
77
78#[derive(Debug, Clone)]
79pub struct SymbolContext {
80 pub id: ContextId,
81 pub context_symbol: ContextSymbol,
82}
83
84#[derive(Debug, Clone)]
85pub struct FetchedUrlContext {
86 pub id: ContextId,
87 pub url: SharedString,
88 pub text: SharedString,
89}
90
// TODO: `Entity<Thread>` is a strong handle, so it keeps the thread alive even after the thread
// is deleted. Either handle deletion explicitly, or hold a weak handle and drop it during snapshot.
93
94#[derive(Debug, Clone)]
95pub struct ThreadContext {
96 pub id: ContextId,
97 pub thread: Entity<Thread>,
98 pub text: SharedString,
99}
100
101impl ThreadContext {
102 pub fn summary(&self, cx: &App) -> SharedString {
103 self.thread
104 .read(cx)
105 .summary()
106 .unwrap_or("New thread".into())
107 }
108}
109
// TODO: `Entity<Buffer>` is a strong handle, so it keeps the buffer alive even after the file is
// deleted and closed. The context should be removed from the message editor in that case.
112
113#[derive(Clone)]
114pub struct ContextBuffer {
115 pub id: BufferId,
116 pub buffer: Entity<Buffer>,
117 pub file: Arc<dyn File>,
118 pub version: clock::Global,
119 pub text: SharedString,
120}
121
122impl std::fmt::Debug for ContextBuffer {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 f.debug_struct("ContextBuffer")
125 .field("id", &self.id)
126 .field("buffer", &self.buffer)
127 .field("version", &self.version)
128 .field("text", &self.text)
129 .finish()
130 }
131}
132
133#[derive(Debug, Clone)]
134pub struct ContextSymbol {
135 pub id: ContextSymbolId,
136 pub buffer: Entity<Buffer>,
137 pub buffer_version: clock::Global,
138 /// The range that the symbol encloses, e.g. for function symbol, this will
139 /// include not only the signature, but also the body
140 pub enclosing_range: Range<Anchor>,
141 pub text: SharedString,
142}
143
144#[derive(Debug, Clone, PartialEq, Eq, Hash)]
145pub struct ContextSymbolId {
146 pub path: ProjectPath,
147 pub name: SharedString,
148 pub range: Range<Anchor>,
149}
150
151pub fn attach_context_to_message<'a>(
152 message: &mut LanguageModelRequestMessage,
153 contexts: impl Iterator<Item = &'a AssistantContext>,
154 cx: &App,
155) {
156 let mut file_context = Vec::new();
157 let mut directory_context = Vec::new();
158 let mut symbol_context = Vec::new();
159 let mut fetch_context = Vec::new();
160 let mut thread_context = Vec::new();
161
162 for context in contexts {
163 match context {
164 AssistantContext::File(context) => file_context.push(context),
165 AssistantContext::Directory(context) => directory_context.push(context),
166 AssistantContext::Symbol(context) => symbol_context.push(context),
167 AssistantContext::FetchedUrl(context) => fetch_context.push(context),
168 AssistantContext::Thread(context) => thread_context.push(context),
169 }
170 }
171
172 let mut context_chunks = Vec::new();
173
174 if !file_context.is_empty() {
175 context_chunks.push("The following files are available:\n");
176 for context in file_context {
177 context_chunks.push(&context.context_buffer.text);
178 }
179 }
180
181 if !directory_context.is_empty() {
182 context_chunks.push("The following directories are available:\n");
183 for context in directory_context {
184 for context_buffer in &context.context_buffers {
185 context_chunks.push(&context_buffer.text);
186 }
187 }
188 }
189
190 if !symbol_context.is_empty() {
191 context_chunks.push("The following symbols are available:\n");
192 for context in symbol_context {
193 context_chunks.push(&context.context_symbol.text);
194 }
195 }
196
197 if !fetch_context.is_empty() {
198 context_chunks.push("The following fetched results are available:\n");
199 for context in &fetch_context {
200 context_chunks.push(&context.url);
201 context_chunks.push(&context.text);
202 }
203 }
204
205 // Need to own the SharedString for summary so that it can be referenced.
206 let mut thread_context_chunks = Vec::new();
207 if !thread_context.is_empty() {
208 context_chunks.push("The following previous conversation threads are available:\n");
209 for context in &thread_context {
210 thread_context_chunks.push(context.summary(cx));
211 thread_context_chunks.push(context.text.clone());
212 }
213 }
214 for chunk in &thread_context_chunks {
215 context_chunks.push(chunk);
216 }
217
218 if !context_chunks.is_empty() {
219 message
220 .content
221 .push(MessageContent::Text(context_chunks.join("\n")));
222 }
223}