1use std::borrow::Cow;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
7use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
8use chrono::{DateTime, Utc};
9use collections::HashMap;
10use context_server::manager::ContextServerManager;
11use context_server::{ContextServerFactoryRegistry, ContextServerTool};
12use futures::FutureExt as _;
13use futures::future::{self, BoxFuture, Shared};
14use gpui::{
15 App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
16 prelude::*,
17};
18use heed::Database;
19use heed::types::SerdeBincode;
20use language_model::{LanguageModelToolUseId, Role, TokenUsage};
21use project::Project;
22use prompt_store::PromptBuilder;
23use serde::{Deserialize, Serialize};
24use settings::{Settings as _, SettingsStore};
25use util::ResultExt as _;
26
27use crate::thread::{
28 DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId,
29};
30
31pub fn init(cx: &mut App) {
32 ThreadsDatabase::init(cx);
33}
34
/// Holds the list of saved agent threads for a project, plus the shared tool configuration
/// handed to the threads it creates or opens.
pub struct ThreadStore {
    /// Project that every created/opened thread is bound to.
    project: Entity<Project>,
    /// Tool working set shared with every thread; tools are enabled/disabled according to the
    /// active agent profile.
    tools: Arc<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    /// Manager for the project's context servers, created alongside the store.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool ids registered in `tools` for each started context server, keyed by server id,
    /// so they can be removed again when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Saved-thread metadata in the order the database listed it (not sorted).
    threads: Vec<SerializedThreadMetadata>,
    /// Subscriptions kept alive for the lifetime of the store (e.g. the settings observer).
    _subscriptions: Vec<Subscription>,
}
44
45impl ThreadStore {
46 pub fn new(
47 project: Entity<Project>,
48 tools: Arc<ToolWorkingSet>,
49 prompt_builder: Arc<PromptBuilder>,
50 cx: &mut App,
51 ) -> Result<Entity<Self>> {
52 let this = cx.new(|cx| {
53 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
54 let context_server_manager = cx.new(|cx| {
55 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
56 });
57 let settings_subscription =
58 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
59 this.load_default_profile(cx);
60 });
61
62 let this = Self {
63 project,
64 tools,
65 prompt_builder,
66 context_server_manager,
67 context_server_tool_ids: HashMap::default(),
68 threads: Vec::new(),
69 _subscriptions: vec![settings_subscription],
70 };
71 this.load_default_profile(cx);
72 this.register_context_server_handlers(cx);
73 this.reload(cx).detach_and_log_err(cx);
74
75 this
76 });
77
78 Ok(this)
79 }
80
81 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
82 self.context_server_manager.clone()
83 }
84
85 pub fn tools(&self) -> Arc<ToolWorkingSet> {
86 self.tools.clone()
87 }
88
89 /// Returns the number of threads.
90 pub fn thread_count(&self) -> usize {
91 self.threads.len()
92 }
93
94 pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
95 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
96 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
97 threads
98 }
99
100 pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
101 self.threads().into_iter().take(limit).collect()
102 }
103
104 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
105 cx.new(|cx| {
106 Thread::new(
107 self.project.clone(),
108 self.tools.clone(),
109 self.prompt_builder.clone(),
110 cx,
111 )
112 })
113 }
114
115 pub fn open_thread(
116 &self,
117 id: &ThreadId,
118 cx: &mut Context<Self>,
119 ) -> Task<Result<Entity<Thread>>> {
120 let id = id.clone();
121 let database_future = ThreadsDatabase::global_future(cx);
122 cx.spawn(async move |this, cx| {
123 let database = database_future.await.map_err(|err| anyhow!(err))?;
124 let thread = database
125 .try_find_thread(id.clone())
126 .await?
127 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
128
129 let thread = this.update(cx, |this, cx| {
130 cx.new(|cx| {
131 Thread::deserialize(
132 id.clone(),
133 thread,
134 this.project.clone(),
135 this.tools.clone(),
136 this.prompt_builder.clone(),
137 cx,
138 )
139 })
140 })?;
141
142 let (system_prompt_context, load_error) = thread
143 .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
144 .await;
145 thread.update(cx, |thread, cx| {
146 thread.set_system_prompt_context(system_prompt_context);
147 if let Some(load_error) = load_error {
148 cx.emit(ThreadEvent::ShowError(load_error));
149 }
150 })?;
151
152 Ok(thread)
153 })
154 }
155
156 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
157 let (metadata, serialized_thread) =
158 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
159
160 let database_future = ThreadsDatabase::global_future(cx);
161 cx.spawn(async move |this, cx| {
162 let serialized_thread = serialized_thread.await?;
163 let database = database_future.await.map_err(|err| anyhow!(err))?;
164 database.save_thread(metadata, serialized_thread).await?;
165
166 this.update(cx, |this, cx| this.reload(cx))?.await
167 })
168 }
169
170 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
171 let id = id.clone();
172 let database_future = ThreadsDatabase::global_future(cx);
173 cx.spawn(async move |this, cx| {
174 let database = database_future.await.map_err(|err| anyhow!(err))?;
175 database.delete_thread(id.clone()).await?;
176
177 this.update(cx, |this, cx| {
178 this.threads.retain(|thread| thread.id != id);
179 cx.notify();
180 })
181 })
182 }
183
184 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
185 let database_future = ThreadsDatabase::global_future(cx);
186 cx.spawn(async move |this, cx| {
187 let threads = database_future
188 .await
189 .map_err(|err| anyhow!(err))?
190 .list_threads()
191 .await?;
192
193 this.update(cx, |this, cx| {
194 this.threads = threads;
195 cx.notify();
196 })
197 })
198 }
199
200 fn load_default_profile(&self, cx: &Context<Self>) {
201 let assistant_settings = AssistantSettings::get_global(cx);
202
203 self.load_profile_by_id(&assistant_settings.default_profile, cx);
204 }
205
206 pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context<Self>) {
207 let assistant_settings = AssistantSettings::get_global(cx);
208
209 if let Some(profile) = assistant_settings.profiles.get(profile_id) {
210 self.load_profile(profile, cx);
211 }
212 }
213
214 pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
215 self.tools.disable_all_tools();
216 self.tools.enable(
217 ToolSource::Native,
218 &profile
219 .tools
220 .iter()
221 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
222 .collect::<Vec<_>>(),
223 );
224
225 if profile.enable_all_context_servers {
226 for context_server in self.context_server_manager.read(cx).all_servers() {
227 self.tools.enable_source(
228 ToolSource::ContextServer {
229 id: context_server.id().into(),
230 },
231 cx,
232 );
233 }
234 } else {
235 for (context_server_id, preset) in &profile.context_servers {
236 self.tools.enable(
237 ToolSource::ContextServer {
238 id: context_server_id.clone().into(),
239 },
240 &preset
241 .tools
242 .iter()
243 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
244 .collect::<Vec<_>>(),
245 )
246 }
247 }
248 }
249
250 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
251 cx.subscribe(
252 &self.context_server_manager.clone(),
253 Self::handle_context_server_event,
254 )
255 .detach();
256 }
257
258 fn handle_context_server_event(
259 &mut self,
260 context_server_manager: Entity<ContextServerManager>,
261 event: &context_server::manager::Event,
262 cx: &mut Context<Self>,
263 ) {
264 let tool_working_set = self.tools.clone();
265 match event {
266 context_server::manager::Event::ServerStarted { server_id } => {
267 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
268 let context_server_manager = context_server_manager.clone();
269 cx.spawn({
270 let server = server.clone();
271 let server_id = server_id.clone();
272 async move |this, cx| {
273 let Some(protocol) = server.client() else {
274 return;
275 };
276
277 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
278 if let Some(tools) = protocol.list_tools().await.log_err() {
279 let tool_ids = tools
280 .tools
281 .into_iter()
282 .map(|tool| {
283 log::info!(
284 "registering context server tool: {:?}",
285 tool.name
286 );
287 tool_working_set.insert(Arc::new(
288 ContextServerTool::new(
289 context_server_manager.clone(),
290 server.id(),
291 tool,
292 ),
293 ))
294 })
295 .collect::<Vec<_>>();
296
297 this.update(cx, |this, cx| {
298 this.context_server_tool_ids.insert(server_id, tool_ids);
299 this.load_default_profile(cx);
300 })
301 .log_err();
302 }
303 }
304 }
305 })
306 .detach();
307 }
308 }
309 context_server::manager::Event::ServerStopped { server_id } => {
310 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
311 tool_working_set.remove(&tool_ids);
312 self.load_default_profile(cx);
313 }
314 }
315 }
316 }
317}
318
/// Lightweight summary of a saved thread, used to list threads without their messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Copied from the stored thread's `summary`.
    pub summary: SharedString,
    /// Copied from the stored thread's `updated_at`; used to order threads newest-first.
    pub updated_at: DateTime<Utc>,
}
325
/// Persisted form of a thread, stored as JSON bytes in the threads database.
///
/// Fields marked `#[serde(default)]` may be missing from older records and fall back to
/// their `Default` values when deserialized.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Format version; `from_json` only accepts [`SerializedThread::VERSION`].
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
}
339
340impl SerializedThread {
341 pub const VERSION: &'static str = "0.1.0";
342
343 pub fn from_json(json: &[u8]) -> Result<Self> {
344 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
345 match saved_thread_json.get("version") {
346 Some(serde_json::Value::String(version)) => match version.as_str() {
347 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
348 saved_thread_json,
349 )?),
350 _ => Err(anyhow!(
351 "unrecognized serialized thread version: {}",
352 version
353 )),
354 },
355 None => {
356 let saved_thread =
357 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
358 Ok(saved_thread.upgrade())
359 }
360 version => Err(anyhow!(
361 "unrecognized serialized thread version: {:?}",
362 version
363 )),
364 }
365 }
366}
367
/// Persisted form of a single message within a [`SerializedThread`].
///
/// Every field except `id` and `role` defaults to empty when absent.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (plain text or model "thinking" text).
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
}
381
/// One piece of a message's content.
///
/// Serialized as an internally tagged object whose `type` field is `"text"` or `"thinking"`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "thinking")]
    Thinking { text: String },
}
390
/// A tool invocation requested within a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Tool input, stored as arbitrary JSON.
    pub input: serde_json::Value,
}
397
/// The result of a tool invocation, linked back to its [`SerializedToolUse`] by id.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
}
404
/// Thread layout from before the `version` field existed.
///
/// `SerializedThread::from_json` falls back to this type when a record has no `version` key
/// and then upgrades it to the current layout.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
413
414impl LegacySerializedThread {
415 pub fn upgrade(self) -> SerializedThread {
416 SerializedThread {
417 version: SerializedThread::VERSION.to_string(),
418 summary: self.summary,
419 updated_at: self.updated_at,
420 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
421 initial_project_snapshot: self.initial_project_snapshot,
422 cumulative_token_usage: TokenUsage::default(),
423 detailed_summary_state: DetailedSummaryState::default(),
424 }
425 }
426}
427
/// Message layout used by [`LegacySerializedThread`]: a single plain-text body instead of
/// segments, and no `context` field.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
438
439impl LegacySerializedMessage {
440 fn upgrade(self) -> SerializedMessage {
441 SerializedMessage {
442 id: self.id,
443 role: self.role,
444 segments: vec![SerializedMessageSegment::Text { text: self.text }],
445 tool_uses: self.tool_uses,
446 tool_results: self.tool_results,
447 context: String::new(),
448 }
449 }
450}
451
/// gpui global holding a shared future that resolves to the opened threads database, or to
/// the error from opening it.
///
/// The future is `Shared`, so every caller awaits the same single open. The error is behind
/// an `Arc` so the output is cloneable.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
457
/// heed (LMDB) store of serialized threads, keyed by [`ThreadId`].
pub(crate) struct ThreadsDatabase {
    /// Executor the blocking database operations are spawned on.
    executor: BackgroundExecutor,
    env: heed::Env,
    /// Keys are bincode-encoded thread ids; values are JSON-encoded `SerializedThread`s.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
463
464impl heed::BytesEncode<'_> for SerializedThread {
465 type EItem = SerializedThread;
466
467 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
468 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
469 }
470}
471
472impl<'a> heed::BytesDecode<'a> for SerializedThread {
473 type DItem = SerializedThread;
474
475 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
476 // We implement this type manually because we want to call `SerializedThread::from_json`,
477 // instead of the Deserialize trait implementation for `SerializedThread`.
478 SerializedThread::from_json(bytes).map_err(Into::into)
479 }
480}
481
482impl ThreadsDatabase {
483 fn global_future(
484 cx: &mut App,
485 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
486 GlobalThreadsDatabase::global(cx).0.clone()
487 }
488
489 fn init(cx: &mut App) {
490 let executor = cx.background_executor().clone();
491 let database_future = executor
492 .spawn({
493 let executor = executor.clone();
494 let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
495 async move { ThreadsDatabase::new(database_path, executor) }
496 })
497 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
498 .boxed()
499 .shared();
500
501 cx.set_global(GlobalThreadsDatabase(database_future));
502 }
503
504 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
505 std::fs::create_dir_all(&path)?;
506
507 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
508 let env = unsafe {
509 heed::EnvOpenOptions::new()
510 .map_size(ONE_GB_IN_BYTES)
511 .max_dbs(1)
512 .open(path)?
513 };
514
515 let mut txn = env.write_txn()?;
516 let threads = env.create_database(&mut txn, Some("threads"))?;
517 txn.commit()?;
518
519 Ok(Self {
520 executor,
521 env,
522 threads,
523 })
524 }
525
526 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
527 let env = self.env.clone();
528 let threads = self.threads;
529
530 self.executor.spawn(async move {
531 let txn = env.read_txn()?;
532 let mut iter = threads.iter(&txn)?;
533 let mut threads = Vec::new();
534 while let Some((key, value)) = iter.next().transpose()? {
535 threads.push(SerializedThreadMetadata {
536 id: key,
537 summary: value.summary,
538 updated_at: value.updated_at,
539 });
540 }
541
542 Ok(threads)
543 })
544 }
545
546 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
547 let env = self.env.clone();
548 let threads = self.threads;
549
550 self.executor.spawn(async move {
551 let txn = env.read_txn()?;
552 let thread = threads.get(&txn, &id)?;
553 Ok(thread)
554 })
555 }
556
557 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
558 let env = self.env.clone();
559 let threads = self.threads;
560
561 self.executor.spawn(async move {
562 let mut txn = env.write_txn()?;
563 threads.put(&mut txn, &id, &thread)?;
564 txn.commit()?;
565 Ok(())
566 })
567 }
568
569 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
570 let env = self.env.clone();
571 let threads = self.threads;
572
573 self.executor.spawn(async move {
574 let mut txn = env.write_txn()?;
575 threads.delete(&mut txn, &id)?;
576 txn.commit()?;
577 Ok(())
578 })
579 }
580}