1use std::ops::Range;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use collections::{HashSet, IndexSet};
7use futures::future::join_all;
8use futures::{self, FutureExt};
9use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity};
10use language::Buffer;
11use language_model::LanguageModelImage;
12use project::{Project, ProjectItem, ProjectPath, Symbol};
13use prompt_store::UserPromptId;
14use ref_cast::RefCast as _;
15use text::{Anchor, OffsetRangeExt};
16use util::ResultExt as _;
17
18use crate::ThreadStore;
19use crate::context::{
20 AgentContext, AgentContextKey, ContextId, DirectoryContext, FetchedUrlContext, FileContext,
21 ImageContext, RulesContext, SelectionContext, SymbolContext, ThreadContext,
22};
23use crate::context_strip::SuggestedContext;
24use crate::thread::{Thread, ThreadId};
25
/// Holds the set of context entries (files, directories, symbols, threads, rules,
/// URLs, images, selections) that have been attached for an agent conversation.
pub struct ContextStore {
    /// Project the context belongs to; held weakly and upgraded on demand.
    project: WeakEntity<Project>,
    /// Optional thread store used to save threads after their detailed summaries
    /// finish generating.
    thread_store: Option<WeakEntity<ThreadStore>>,
    /// Pending thread-summary tasks; drained by `wait_for_summaries`.
    thread_summary_tasks: Vec<Task<()>>,
    /// Next id handed out to a newly created context entry.
    next_context_id: ContextId,
    /// Context entries in insertion order; uniqueness follows `AgentContextKey` equality.
    context_set: IndexSet<AgentContextKey>,
    /// Ids of threads currently present in `context_set`, kept in sync by
    /// `insert_context`, `remove_context` and `clear`.
    context_thread_ids: HashSet<ThreadId>,
}
34
35impl ContextStore {
36 pub fn new(
37 project: WeakEntity<Project>,
38 thread_store: Option<WeakEntity<ThreadStore>>,
39 ) -> Self {
40 Self {
41 project,
42 thread_store,
43 thread_summary_tasks: Vec::new(),
44 next_context_id: ContextId::zero(),
45 context_set: IndexSet::default(),
46 context_thread_ids: HashSet::default(),
47 }
48 }
49
50 pub fn context(&self) -> impl Iterator<Item = &AgentContext> {
51 self.context_set.iter().map(|entry| entry.as_ref())
52 }
53
54 pub fn clear(&mut self) {
55 self.context_set.clear();
56 self.context_thread_ids.clear();
57 }
58
59 pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContext> {
60 let existing_context = thread
61 .messages()
62 .flat_map(|message| &message.loaded_context.contexts)
63 .map(AgentContextKey::ref_cast)
64 .collect::<HashSet<_>>();
65 self.context_set
66 .iter()
67 .filter(|context| !existing_context.contains(context))
68 .map(|entry| entry.0.clone())
69 .collect::<Vec<_>>()
70 }
71
72 pub fn add_file_from_path(
73 &mut self,
74 project_path: ProjectPath,
75 remove_if_exists: bool,
76 cx: &mut Context<Self>,
77 ) -> Task<Result<()>> {
78 let Some(project) = self.project.upgrade() else {
79 return Task::ready(Err(anyhow!("failed to read project")));
80 };
81
82 cx.spawn(async move |this, cx| {
83 let open_buffer_task = project.update(cx, |project, cx| {
84 project.open_buffer(project_path.clone(), cx)
85 })?;
86 let buffer = open_buffer_task.await?;
87 this.update(cx, |this, cx| {
88 this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
89 })
90 })
91 }
92
93 pub fn add_file_from_buffer(
94 &mut self,
95 project_path: &ProjectPath,
96 buffer: Entity<Buffer>,
97 remove_if_exists: bool,
98 cx: &mut Context<Self>,
99 ) {
100 let context_id = self.next_context_id.post_inc();
101 let context = AgentContext::File(FileContext { buffer, context_id });
102
103 let already_included = if self.has_context(&context) {
104 if remove_if_exists {
105 self.remove_context(&context, cx);
106 }
107 true
108 } else {
109 self.path_included_in_directory(project_path, cx).is_some()
110 };
111
112 if !already_included {
113 self.insert_context(context, cx);
114 }
115 }
116
117 pub fn add_directory(
118 &mut self,
119 project_path: &ProjectPath,
120 remove_if_exists: bool,
121 cx: &mut Context<Self>,
122 ) -> Result<()> {
123 let Some(project) = self.project.upgrade() else {
124 return Err(anyhow!("failed to read project"));
125 };
126
127 let Some(entry_id) = project
128 .read(cx)
129 .entry_for_path(project_path, cx)
130 .map(|entry| entry.id)
131 else {
132 return Err(anyhow!("no entry found for directory context"));
133 };
134
135 let context_id = self.next_context_id.post_inc();
136 let context = AgentContext::Directory(DirectoryContext {
137 entry_id,
138 context_id,
139 });
140
141 if self.has_context(&context) {
142 if remove_if_exists {
143 self.remove_context(&context, cx);
144 }
145 } else if self.path_included_in_directory(project_path, cx).is_none() {
146 self.insert_context(context, cx);
147 }
148
149 anyhow::Ok(())
150 }
151
152 pub fn add_symbol(
153 &mut self,
154 buffer: Entity<Buffer>,
155 symbol: SharedString,
156 range: Range<Anchor>,
157 enclosing_range: Range<Anchor>,
158 remove_if_exists: bool,
159 cx: &mut Context<Self>,
160 ) -> bool {
161 let context_id = self.next_context_id.post_inc();
162 let context = AgentContext::Symbol(SymbolContext {
163 buffer,
164 symbol,
165 range,
166 enclosing_range,
167 context_id,
168 });
169
170 if self.has_context(&context) {
171 if remove_if_exists {
172 self.remove_context(&context, cx);
173 }
174 return false;
175 }
176
177 self.insert_context(context, cx)
178 }
179
180 pub fn add_thread(
181 &mut self,
182 thread: Entity<Thread>,
183 remove_if_exists: bool,
184 cx: &mut Context<Self>,
185 ) {
186 let context_id = self.next_context_id.post_inc();
187 let context = AgentContext::Thread(ThreadContext { thread, context_id });
188
189 if self.has_context(&context) {
190 if remove_if_exists {
191 self.remove_context(&context, cx);
192 }
193 } else {
194 self.insert_context(context, cx);
195 }
196 }
197
198 fn start_summarizing_thread_if_needed(
199 &mut self,
200 thread: &Entity<Thread>,
201 cx: &mut Context<Self>,
202 ) {
203 if let Some(summary_task) =
204 thread.update(cx, |thread, cx| thread.generate_detailed_summary(cx))
205 {
206 let thread = thread.clone();
207 let thread_store = self.thread_store.clone();
208
209 self.thread_summary_tasks.push(cx.spawn(async move |_, cx| {
210 summary_task.await;
211
212 if let Some(thread_store) = thread_store {
213 // Save thread so its summary can be reused later
214 let save_task = thread_store
215 .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx));
216
217 if let Some(save_task) = save_task.ok() {
218 save_task.await.log_err();
219 }
220 }
221 }));
222 }
223 }
224
225 pub fn wait_for_summaries(&mut self, cx: &App) -> Task<()> {
226 let tasks = std::mem::take(&mut self.thread_summary_tasks);
227
228 cx.spawn(async move |_cx| {
229 join_all(tasks).await;
230 })
231 }
232
233 pub fn add_rules(
234 &mut self,
235 prompt_id: UserPromptId,
236 remove_if_exists: bool,
237 cx: &mut Context<ContextStore>,
238 ) {
239 let context_id = self.next_context_id.post_inc();
240 let context = AgentContext::Rules(RulesContext {
241 prompt_id,
242 context_id,
243 });
244
245 if self.has_context(&context) {
246 if remove_if_exists {
247 self.remove_context(&context, cx);
248 }
249 } else {
250 self.insert_context(context, cx);
251 }
252 }
253
254 pub fn add_fetched_url(
255 &mut self,
256 url: String,
257 text: impl Into<SharedString>,
258 cx: &mut Context<ContextStore>,
259 ) {
260 let context = AgentContext::FetchedUrl(FetchedUrlContext {
261 url: url.into(),
262 text: text.into(),
263 context_id: self.next_context_id.post_inc(),
264 });
265
266 self.insert_context(context, cx);
267 }
268
269 pub fn add_image(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
270 let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
271 let context = AgentContext::Image(ImageContext {
272 original_image: image,
273 image_task,
274 context_id: self.next_context_id.post_inc(),
275 });
276 self.insert_context(context, cx);
277 }
278
279 pub fn add_selection(
280 &mut self,
281 buffer: Entity<Buffer>,
282 range: Range<Anchor>,
283 cx: &mut Context<ContextStore>,
284 ) {
285 let context_id = self.next_context_id.post_inc();
286 let context = AgentContext::Selection(SelectionContext {
287 buffer,
288 range,
289 context_id,
290 });
291 self.insert_context(context, cx);
292 }
293
294 pub fn add_suggested_context(
295 &mut self,
296 suggested: &SuggestedContext,
297 cx: &mut Context<ContextStore>,
298 ) {
299 match suggested {
300 SuggestedContext::File {
301 buffer,
302 icon_path: _,
303 name: _,
304 } => {
305 if let Some(buffer) = buffer.upgrade() {
306 let context_id = self.next_context_id.post_inc();
307 self.insert_context(AgentContext::File(FileContext { buffer, context_id }), cx);
308 };
309 }
310 SuggestedContext::Thread { thread, name: _ } => {
311 if let Some(thread) = thread.upgrade() {
312 let context_id = self.next_context_id.post_inc();
313 self.insert_context(
314 AgentContext::Thread(ThreadContext { thread, context_id }),
315 cx,
316 );
317 }
318 }
319 }
320 }
321
322 fn insert_context(&mut self, context: AgentContext, cx: &mut Context<Self>) -> bool {
323 match &context {
324 AgentContext::Thread(thread_context) => {
325 self.context_thread_ids
326 .insert(thread_context.thread.read(cx).id().clone());
327 self.start_summarizing_thread_if_needed(&thread_context.thread, cx);
328 }
329 _ => {}
330 }
331 let inserted = self.context_set.insert(AgentContextKey(context));
332 if inserted {
333 cx.notify();
334 }
335 inserted
336 }
337
338 pub fn remove_context(&mut self, context: &AgentContext, cx: &mut Context<Self>) {
339 if self
340 .context_set
341 .shift_remove(AgentContextKey::ref_cast(context))
342 {
343 match context {
344 AgentContext::Thread(thread_context) => {
345 self.context_thread_ids
346 .remove(thread_context.thread.read(cx).id());
347 }
348 _ => {}
349 }
350 cx.notify();
351 }
352 }
353
354 pub fn has_context(&mut self, context: &AgentContext) -> bool {
355 self.context_set
356 .contains(AgentContextKey::ref_cast(context))
357 }
358
359 /// Returns whether this file path is already included directly in the context, or if it will be
360 /// included in the context via a directory.
361 pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
362 let project = self.project.upgrade()?.read(cx);
363 self.context().find_map(|context| match context {
364 AgentContext::File(file_context) => FileInclusion::check_file(file_context, path, cx),
365 AgentContext::Directory(directory_context) => {
366 FileInclusion::check_directory(directory_context, path, project, cx)
367 }
368 _ => None,
369 })
370 }
371
372 pub fn path_included_in_directory(
373 &self,
374 path: &ProjectPath,
375 cx: &App,
376 ) -> Option<FileInclusion> {
377 let project = self.project.upgrade()?.read(cx);
378 self.context().find_map(|context| match context {
379 AgentContext::Directory(directory_context) => {
380 FileInclusion::check_directory(directory_context, path, project, cx)
381 }
382 _ => None,
383 })
384 }
385
386 pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
387 self.context().any(|context| match context {
388 AgentContext::Symbol(context) => {
389 if context.symbol != symbol.name {
390 return false;
391 }
392 let buffer = context.buffer.read(cx);
393 let Some(context_path) = buffer.project_path(cx) else {
394 return false;
395 };
396 if context_path != symbol.path {
397 return false;
398 }
399 let context_range = context.range.to_point_utf16(&buffer.snapshot());
400 context_range.start == symbol.range.start.0
401 && context_range.end == symbol.range.end.0
402 }
403 _ => false,
404 })
405 }
406
407 pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
408 self.context_thread_ids.contains(thread_id)
409 }
410
411 pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
412 self.context_set
413 .contains(&RulesContext::lookup_key(prompt_id))
414 }
415
416 pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
417 self.context_set
418 .contains(&FetchedUrlContext::lookup_key(url.into()))
419 }
420
421 pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
422 self.context()
423 .filter_map(|context| match context {
424 AgentContext::File(file) => {
425 let buffer = file.buffer.read(cx);
426 buffer.project_path(cx)
427 }
428 AgentContext::Directory(_)
429 | AgentContext::Symbol(_)
430 | AgentContext::Selection(_)
431 | AgentContext::FetchedUrl(_)
432 | AgentContext::Thread(_)
433 | AgentContext::Rules(_)
434 | AgentContext::Image(_) => None,
435 })
436 .collect()
437 }
438
439 pub fn thread_ids(&self) -> &HashSet<ThreadId> {
440 &self.context_thread_ids
441 }
442}
443
/// Describes how a project path is covered by the current context.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileInclusion {
    /// The path itself is a context entry (a file entry, or a directory entry
    /// whose path equals it).
    Direct,
    /// The path lies inside a directory entry that is in context.
    InDirectory {
        /// Full path of the matched directory, as returned by the worktree's
        /// `full_path`.
        full_path: PathBuf,
    },
}
448
449impl FileInclusion {
450 fn check_file(file_context: &FileContext, path: &ProjectPath, cx: &App) -> Option<Self> {
451 let file_path = file_context.buffer.read(cx).project_path(cx)?;
452 if path == &file_path {
453 Some(FileInclusion::Direct)
454 } else {
455 None
456 }
457 }
458
459 fn check_directory(
460 directory_context: &DirectoryContext,
461 path: &ProjectPath,
462 project: &Project,
463 cx: &App,
464 ) -> Option<Self> {
465 let worktree = project
466 .worktree_for_entry(directory_context.entry_id, cx)?
467 .read(cx);
468 let entry = worktree.entry_for_id(directory_context.entry_id)?;
469 let directory_path = ProjectPath {
470 worktree_id: worktree.id(),
471 path: entry.path.clone(),
472 };
473 if path.starts_with(&directory_path) {
474 if path == &directory_path {
475 Some(FileInclusion::Direct)
476 } else {
477 Some(FileInclusion::InDirectory {
478 full_path: worktree.full_path(&entry.path),
479 })
480 }
481 } else {
482 None
483 }
484 }
485}