Detailed changes
@@ -3280,7 +3280,10 @@ dependencies = [
name = "collab"
version = "0.44.0"
dependencies = [
+ "agent",
+ "agent-client-protocol",
"agent_settings",
+ "agent_ui",
"anyhow",
"assistant_slash_command",
"assistant_text_thread",
@@ -20662,6 +20665,8 @@ version = "0.219.0"
dependencies = [
"acp_tools",
"activity_indicator",
+ "agent",
+ "agent-client-protocol",
"agent_settings",
"agent_ui",
"agent_ui_v2",
@@ -0,0 +1,6 @@
+collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all
+cloud: cd ../cloud; cargo make dev
+dashboard: cd ../cloud/packages/dashboard; pnpm dev
+website: cd ../zed.dev; pnpm dev --port=3000
+livekit: livekit-server --dev
+blob_store: ./script/run-local-minio
@@ -50,6 +50,63 @@ pub struct DbThread {
pub completion_mode: Option<CompletionMode>,
#[serde(default)]
pub profile: Option<AgentProfileId>,
+ #[serde(default)]
+ pub imported: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SharedThread {
+ pub title: SharedString,
+ pub messages: Vec<DbMessage>,
+ pub updated_at: DateTime<Utc>,
+ #[serde(default)]
+ pub model: Option<DbLanguageModel>,
+ #[serde(default)]
+ pub completion_mode: Option<CompletionMode>,
+ pub version: String,
+}
+
+impl SharedThread {
+ pub const VERSION: &'static str = "1.0.0";
+
+ pub fn from_db_thread(thread: &DbThread) -> Self {
+ Self {
+ title: thread.title.clone(),
+ messages: thread.messages.clone(),
+ updated_at: thread.updated_at,
+ model: thread.model.clone(),
+ completion_mode: thread.completion_mode,
+ version: Self::VERSION.to_string(),
+ }
+ }
+
+ pub fn to_db_thread(self) -> DbThread {
+ DbThread {
+ title: format!("🔗 {}", self.title).into(),
+ messages: self.messages,
+ updated_at: self.updated_at,
+ detailed_summary: None,
+ initial_project_snapshot: None,
+ cumulative_token_usage: Default::default(),
+ request_token_usage: Default::default(),
+ model: self.model,
+ completion_mode: self.completion_mode,
+ profile: None,
+ imported: true,
+ }
+ }
+
+ pub fn to_bytes(&self) -> Result<Vec<u8>> {
+ const COMPRESSION_LEVEL: i32 = 3;
+ let json = serde_json::to_vec(self)?;
+ let compressed = zstd::encode_all(json.as_slice(), COMPRESSION_LEVEL)?;
+ Ok(compressed)
+ }
+
+ pub fn from_bytes(data: &[u8]) -> Result<Self> {
+ let decompressed = zstd::decode_all(data)?;
+ Ok(serde_json::from_slice(&decompressed)?)
+ }
}
impl DbThread {
@@ -209,6 +266,7 @@ impl DbThread {
model: thread.model,
completion_mode: thread.completion_mode,
profile: thread.profile,
+ imported: false,
})
}
}
@@ -441,3 +499,45 @@ impl ThreadsDatabase {
})
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use chrono::TimeZone;
+
+ #[test]
+ fn test_shared_thread_roundtrip() {
+ let original = SharedThread {
+ title: "Test Thread".into(),
+ messages: vec![],
+ updated_at: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
+ model: None,
+ completion_mode: None,
+ version: SharedThread::VERSION.to_string(),
+ };
+
+ let bytes = original.to_bytes().expect("Failed to serialize");
+ let restored = SharedThread::from_bytes(&bytes).expect("Failed to deserialize");
+
+ assert_eq!(restored.title, original.title);
+ assert_eq!(restored.version, original.version);
+ assert_eq!(restored.updated_at, original.updated_at);
+ }
+
+ #[test]
+ fn test_imported_flag_defaults_to_false() {
+ // Simulate deserializing a thread without the imported field (backwards compatibility).
+ let json = r#"{
+ "title": "Old Thread",
+ "messages": [],
+ "updated_at": "2024-01-01T00:00:00Z"
+ }"#;
+
+ let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
+
+ assert!(
+ !db_thread.imported,
+ "Legacy threads without imported field should default to false"
+ );
+ }
+}
@@ -175,6 +175,20 @@ impl HistoryStore {
})
}
+ pub fn save_thread(
+ &mut self,
+ id: acp::SessionId,
+ thread: crate::DbThread,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<()>> {
+ let database_future = ThreadsDatabase::connect(cx);
+ cx.spawn(async move |this, cx| {
+ let database = database_future.await.map_err(|err| anyhow!(err))?;
+ database.save_thread(id, thread).await?;
+ this.update(cx, |this, cx| this.reload(cx))
+ })
+ }
+
pub fn delete_thread(
&mut self,
id: acp::SessionId,
@@ -44,7 +44,7 @@ pub struct SerializedThread {
pub profile: Option<AgentProfileId>,
}
-#[derive(Serialize, Deserialize, Debug, PartialEq)]
+#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
pub struct SerializedLanguageModel {
pub provider: String,
pub model: String,
@@ -622,6 +622,8 @@ pub struct Thread {
pub(crate) action_log: Entity<ActionLog>,
/// Tracks the last time files were read by the agent, to detect external modifications
pub(crate) file_read_times: HashMap<PathBuf, fs::MTime>,
+ /// True if this thread was imported from a shared thread and can be synced.
+ imported: bool,
}
impl Thread {
@@ -678,6 +680,7 @@ impl Thread {
project,
action_log,
file_read_times: HashMap::default(),
+ imported: false,
}
}
@@ -685,6 +688,11 @@ impl Thread {
&self.id
}
+ /// Returns true if this thread was imported from a shared thread.
+ pub fn is_imported(&self) -> bool {
+ self.imported
+ }
+
pub fn replay(
&mut self,
cx: &mut Context<Self>,
@@ -866,6 +874,7 @@ impl Thread {
prompt_capabilities_tx,
prompt_capabilities_rx,
file_read_times: HashMap::default(),
+ imported: db_thread.imported,
}
}
@@ -885,6 +894,7 @@ impl Thread {
}),
completion_mode: Some(self.completion_mode),
profile: Some(self.profile_id.clone()),
+ imported: self.imported,
};
cx.background_spawn(async move {
@@ -5,7 +5,9 @@ use acp_thread::{
};
use acp_thread::{AgentConnection, Plan};
use action_log::{ActionLog, ActionLogTelemetry};
-use agent::{DbThreadMetadata, HistoryEntry, HistoryEntryId, HistoryStore, NativeAgentServer};
+use agent::{
+ DbThreadMetadata, HistoryEntry, HistoryEntryId, HistoryStore, NativeAgentServer, SharedThread,
+};
use agent_client_protocol::{self as acp, PromptCapabilities};
use agent_servers::{AgentServer, AgentServerDelegate};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
@@ -20,15 +22,16 @@ use editor::scroll::Autoscroll;
use editor::{
Editor, EditorEvent, EditorMode, MultiBuffer, PathKey, SelectionEffects, SizingBehavior,
};
+use feature_flags::{AgentSharingFeatureFlag, FeatureFlagAppExt};
use file_icons::FileIcons;
use fs::Fs;
use futures::FutureExt as _;
use gpui::{
- Action, Animation, AnimationExt, AnyView, App, BorderStyle, ClickEvent, CursorStyle,
- EdgesRefinement, ElementId, Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset,
- ListState, PlatformDisplay, SharedString, StyleRefinement, Subscription, Task, TextStyle,
- TextStyleRefinement, UnderlineStyle, WeakEntity, Window, WindowHandle, div, ease_in_out,
- linear_color_stop, linear_gradient, list, point, pulsating_between,
+ Action, Animation, AnimationExt, AnyView, App, BorderStyle, ClickEvent, ClipboardItem,
+ CursorStyle, EdgesRefinement, ElementId, Empty, Entity, FocusHandle, Focusable, Hsla, Length,
+ ListOffset, ListState, PlatformDisplay, SharedString, StyleRefinement, Subscription, Task,
+ TextStyle, TextStyleRefinement, UnderlineStyle, WeakEntity, Window, WindowHandle, div,
+ ease_in_out, linear_color_stop, linear_gradient, list, point, pulsating_between,
};
use language::Buffer;
@@ -52,7 +55,7 @@ use ui::{
WithScrollbar, prelude::*, right_click_menu,
};
use util::{ResultExt, size::format_file_size, time::duration_alt_display};
-use workspace::{CollaboratorId, NewTerminal, Workspace};
+use workspace::{CollaboratorId, NewTerminal, Toast, Workspace, notifications::NotificationId};
use zed_actions::agent::{Chat, ToggleModelSelector};
use zed_actions::assistant::OpenRulesLibrary;
@@ -935,6 +938,124 @@ impl AcpThreadView {
}
}
+ fn share_thread(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
+ let Some(thread) = self.as_native_thread(cx) else {
+ return;
+ };
+
+ let client = self.project.read(cx).client();
+ let workspace = self.workspace.clone();
+ let session_id = thread.read(cx).id().to_string();
+
+ let load_task = thread.read(cx).to_db(cx);
+
+ cx.spawn(async move |_this, cx| {
+ let db_thread = load_task.await;
+
+ let shared_thread = SharedThread::from_db_thread(&db_thread);
+ let thread_data = shared_thread.to_bytes()?;
+ let title = shared_thread.title.to_string();
+
+ client
+ .request(proto::ShareAgentThread {
+ session_id: session_id.clone(),
+ title,
+ thread_data,
+ })
+ .await?;
+
+ let share_url = client::zed_urls::shared_agent_thread_url(&session_id);
+
+ cx.update(|cx| {
+ if let Some(workspace) = workspace.upgrade() {
+ workspace.update(cx, |workspace, cx| {
+ struct ThreadSharedToast;
+ workspace.show_toast(
+ Toast::new(
+ NotificationId::unique::<ThreadSharedToast>(),
+ "Thread shared!",
+ )
+ .on_click(
+ "Copy URL",
+ move |_window, cx| {
+ cx.write_to_clipboard(ClipboardItem::new_string(
+ share_url.clone(),
+ ));
+ },
+ ),
+ cx,
+ );
+ });
+ }
+ })?;
+
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn sync_thread(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ if !self.is_imported_thread(cx) {
+ return;
+ }
+
+ let Some(thread) = self.as_native_thread(cx) else {
+ return;
+ };
+
+ let client = self.project.read(cx).client();
+ let history_store = self.history_store.clone();
+ let session_id = thread.read(cx).id().clone();
+
+ cx.spawn_in(window, async move |this, cx| {
+ let response = client
+ .request(proto::GetSharedAgentThread {
+ session_id: session_id.to_string(),
+ })
+ .await?;
+
+ let shared_thread = SharedThread::from_bytes(&response.thread_data)?;
+
+ let db_thread = shared_thread.to_db_thread();
+
+ history_store
+ .update(&mut cx.clone(), |store, cx| {
+ store.save_thread(session_id.clone(), db_thread, cx)
+ })?
+ .await?;
+
+ let thread_metadata = agent::DbThreadMetadata {
+ id: session_id,
+ title: format!("🔗 {}", response.title).into(),
+ updated_at: chrono::Utc::now(),
+ };
+
+ this.update_in(cx, |this, window, cx| {
+ this.resume_thread_metadata = Some(thread_metadata);
+ this.reset(window, cx);
+ })?;
+
+ this.update_in(cx, |this, _window, cx| {
+ if let Some(workspace) = this.workspace.upgrade() {
+ workspace.update(cx, |workspace, cx| {
+ struct ThreadSyncedToast;
+ workspace.show_toast(
+ Toast::new(
+ NotificationId::unique::<ThreadSyncedToast>(),
+ "Thread synced with latest version",
+ )
+ .autohide(),
+ cx,
+ );
+ });
+ }
+ })?;
+
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+ }
+
pub fn expand_message_editor(
&mut self,
_: &ExpandMessageEditor,
@@ -4904,6 +5025,13 @@ impl AcpThreadView {
.thread(acp_thread.session_id(), cx)
}
+ fn is_imported_thread(&self, cx: &App) -> bool {
+ let Some(thread) = self.as_native_thread(cx) else {
+ return false;
+ };
+ thread.read(cx).is_imported()
+ }
+
fn is_using_zed_ai_models(&self, cx: &App) -> bool {
self.as_native_thread(cx)
.and_then(|thread| thread.read(cx).model())
@@ -5819,6 +5947,41 @@ impl AcpThreadView {
);
}
+ if cx.has_flag::<AgentSharingFeatureFlag>()
+ && self.is_imported_thread(cx)
+ && self
+ .project
+ .read(cx)
+ .client()
+ .status()
+ .borrow()
+ .is_connected()
+ {
+ let sync_button = IconButton::new("sync-thread", IconName::ArrowCircle)
+ .shape(ui::IconButtonShape::Square)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Ignored)
+ .tooltip(Tooltip::text("Sync with source thread"))
+ .on_click(cx.listener(move |this, _, window, cx| {
+ this.sync_thread(window, cx);
+ }));
+
+ container = container.child(sync_button);
+ }
+
+ if cx.has_flag::<AgentSharingFeatureFlag>() && !self.is_imported_thread(cx) {
+ let share_button = IconButton::new("share-thread", IconName::ArrowUpRight)
+ .shape(ui::IconButtonShape::Square)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Ignored)
+ .tooltip(Tooltip::text("Share Thread"))
+ .on_click(cx.listener(move |this, _, window, cx| {
+ this.share_thread(window, cx);
+ }));
+
+ container = container.child(share_button);
+ }
+
container
.child(open_as_markdown)
.child(scroll_to_recent_user_prompt)
@@ -720,10 +720,25 @@ impl AgentPanel {
&self.prompt_store
}
- pub(crate) fn thread_store(&self) -> &Entity<HistoryStore> {
+ pub fn thread_store(&self) -> &Entity<HistoryStore> {
&self.history_store
}
+ pub fn open_thread(
+ &mut self,
+ thread: DbThreadMetadata,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ self.external_thread(
+ Some(crate::ExternalAgent::NativeAgent),
+ Some(thread),
+ None,
+ window,
+ cx,
+ );
+ }
+
pub(crate) fn context_server_registry(&self) -> &Entity<ContextServerRegistry> {
&self.context_server_registry
}
@@ -67,3 +67,7 @@ pub fn edit_prediction_docs(cx: &App) -> String {
server_url = server_url(cx)
)
}
+
+pub fn shared_agent_thread_url(session_id: &str) -> String {
+ format!("zed://agent/shared/{}", session_id)
+}
@@ -69,7 +69,10 @@ util.workspace = true
uuid.workspace = true
[dev-dependencies]
+agent = { workspace = true, features = ["test-support"] }
+agent-client-protocol.workspace = true
agent_settings.workspace = true
+agent_ui = { workspace = true, features = ["test-support"] }
assistant_text_thread.workspace = true
assistant_slash_command.workspace = true
async-trait.workspace = true
@@ -460,3 +460,14 @@ CREATE TABLE IF NOT EXISTS "breakpoints" (
);
CREATE INDEX "index_breakpoints_on_project_id" ON "breakpoints" ("project_id");
+
+CREATE TABLE IF NOT EXISTS "shared_threads" (
+ "id" TEXT PRIMARY KEY NOT NULL,
+ "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+ "title" VARCHAR(512) NOT NULL,
+ "data" BLOB NOT NULL,
+ "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ "updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+
+CREATE INDEX "index_shared_threads_user_id" ON "shared_threads" ("user_id");
@@ -430,6 +430,15 @@ CREATE SEQUENCE public.servers_id_seq
ALTER SEQUENCE public.servers_id_seq OWNED BY public.servers.id;
+CREATE TABLE public.shared_threads (
+ id uuid NOT NULL,
+ user_id integer NOT NULL,
+ title text NOT NULL,
+ data bytea NOT NULL,
+ created_at timestamp without time zone DEFAULT now() NOT NULL,
+ updated_at timestamp without time zone DEFAULT now() NOT NULL
+);
+
CREATE TABLE public.user_features (
user_id integer NOT NULL,
feature_id integer NOT NULL
@@ -630,6 +639,9 @@ ALTER TABLE ONLY public.rooms
ALTER TABLE ONLY public.servers
ADD CONSTRAINT servers_pkey PRIMARY KEY (id);
+ALTER TABLE ONLY public.shared_threads
+ ADD CONSTRAINT shared_threads_pkey PRIMARY KEY (id);
+
ALTER TABLE ONLY public.user_features
ADD CONSTRAINT user_features_pkey PRIMARY KEY (user_id, feature_id);
@@ -648,6 +660,8 @@ ALTER TABLE ONLY public.worktree_settings_files
ALTER TABLE ONLY public.worktrees
ADD CONSTRAINT worktrees_pkey PRIMARY KEY (project_id, id);
+CREATE INDEX idx_shared_threads_user_id ON public.shared_threads USING btree (user_id);
+
CREATE INDEX index_access_tokens_user_id ON public.access_tokens USING btree (user_id);
CREATE INDEX index_breakpoints_on_project_id ON public.breakpoints USING btree (project_id);
@@ -879,6 +893,9 @@ ALTER TABLE ONLY public.room_participants
ALTER TABLE ONLY public.rooms
ADD CONSTRAINT rooms_channel_id_fkey FOREIGN KEY (channel_id) REFERENCES public.channels(id) ON DELETE CASCADE;
+ALTER TABLE ONLY public.shared_threads
+ ADD CONSTRAINT shared_threads_user_id_fkey FOREIGN KEY (user_id) REFERENCES public.users(id) ON DELETE CASCADE;
+
ALTER TABLE ONLY public.user_features
ADD CONSTRAINT user_features_feature_id_fkey FOREIGN KEY (feature_id) REFERENCES public.feature_flags(id) ON DELETE CASCADE;
@@ -2,6 +2,7 @@ use crate::Result;
use rpc::proto;
use sea_orm::{DbErr, entity::prelude::*};
use serde::{Deserialize, Serialize};
+use uuid::Uuid;
#[macro_export]
macro_rules! id_type {
@@ -92,6 +93,39 @@ id_type!(ServerId);
id_type!(SignupId);
id_type!(UserId);
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, DeriveValueType)]
+pub struct SharedThreadId(pub Uuid);
+
+impl SharedThreadId {
+ pub fn from_proto(id: String) -> Option<Self> {
+ Uuid::parse_str(&id).ok().map(SharedThreadId)
+ }
+
+ pub fn to_proto(self) -> String {
+ self.0.to_string()
+ }
+}
+
+impl sea_orm::TryFromU64 for SharedThreadId {
+ fn try_from_u64(_n: u64) -> std::result::Result<Self, DbErr> {
+ Err(DbErr::ConvertFromU64(
+ "SharedThreadId uses UUID and cannot be converted from u64",
+ ))
+ }
+}
+
+impl sea_orm::sea_query::Nullable for SharedThreadId {
+ fn null() -> Value {
+ Value::Uuid(None)
+ }
+}
+
+impl std::fmt::Display for SharedThreadId {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
/// ChannelRole gives you permissions for both channels and calls.
#[derive(
Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash, Serialize,
@@ -10,4 +10,5 @@ pub mod notifications;
pub mod projects;
pub mod rooms;
pub mod servers;
+pub mod shared_threads;
pub mod users;
@@ -0,0 +1,77 @@
+use chrono::Utc;
+
+use super::*;
+use crate::db::tables::shared_thread;
+
+impl Database {
+ pub async fn upsert_shared_thread(
+ &self,
+ id: SharedThreadId,
+ user_id: UserId,
+ title: &str,
+ data: Vec<u8>,
+ ) -> Result<()> {
+ let title = title.to_string();
+ self.transaction(|tx| {
+ let title = title.clone();
+ let data = data.clone();
+ async move {
+ let now = Utc::now().naive_utc();
+
+ let existing = shared_thread::Entity::find_by_id(id).one(&*tx).await?;
+
+ match existing {
+ Some(existing) => {
+ if existing.user_id != user_id {
+ Err(anyhow!("Cannot update shared thread owned by another user"))?;
+ }
+
+ let mut active: shared_thread::ActiveModel = existing.into();
+ active.title = ActiveValue::Set(title);
+ active.data = ActiveValue::Set(data);
+ active.updated_at = ActiveValue::Set(now);
+ active.update(&*tx).await?;
+ }
+ None => {
+ shared_thread::ActiveModel {
+ id: ActiveValue::Set(id),
+ user_id: ActiveValue::Set(user_id),
+ title: ActiveValue::Set(title),
+ data: ActiveValue::Set(data),
+ created_at: ActiveValue::Set(now),
+ updated_at: ActiveValue::Set(now),
+ }
+ .insert(&*tx)
+ .await?;
+ }
+ }
+
+ Ok(())
+ }
+ })
+ .await
+ }
+
+ pub async fn get_shared_thread(
+ &self,
+ share_id: SharedThreadId,
+ ) -> Result<Option<(shared_thread::Model, String)>> {
+ self.transaction(|tx| async move {
+ let Some(thread) = shared_thread::Entity::find_by_id(share_id)
+ .one(&*tx)
+ .await?
+ else {
+ return Ok(None);
+ };
+
+ let user = user::Entity::find_by_id(thread.user_id).one(&*tx).await?;
+
+ let username = user
+ .map(|u| u.github_login)
+ .unwrap_or_else(|| "Unknown".to_string());
+
+ Ok(Some((thread, username)))
+ })
+ .await
+ }
+}
@@ -22,6 +22,7 @@ pub mod project_repository_statuses;
pub mod room;
pub mod room_participant;
pub mod server;
+pub mod shared_thread;
pub mod user;
pub mod worktree;
pub mod worktree_diagnostic_summary;
@@ -0,0 +1,32 @@
+use crate::db::{SharedThreadId, UserId};
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "shared_threads")]
+pub struct Model {
+ #[sea_orm(primary_key, auto_increment = false)]
+ pub id: SharedThreadId,
+ pub user_id: UserId,
+ pub title: String,
+ pub data: Vec<u8>,
+ pub created_at: DateTime,
+ pub updated_at: DateTime,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+ #[sea_orm(
+ belongs_to = "super::user::Entity",
+ from = "Column::UserId",
+ to = "super::user::Column::Id"
+ )]
+ User,
+}
+
+impl Related<super::user::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::User.def()
+ }
+}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -586,3 +586,121 @@ async fn test_fuzzy_search_users(cx: &mut gpui::TestAppContext) {
.collect::<Vec<_>>()
}
}
+
+test_both_dbs!(
+ test_upsert_shared_thread,
+ test_upsert_shared_thread_postgres,
+ test_upsert_shared_thread_sqlite
+);
+
+async fn test_upsert_shared_thread(db: &Arc<Database>) {
+ use crate::db::SharedThreadId;
+ use uuid::Uuid;
+
+ let user_id = new_test_user(db, "user1@example.com").await;
+
+ let thread_id = SharedThreadId(Uuid::new_v4());
+ let title = "My Test Thread";
+ let data = b"test thread data".to_vec();
+
+ db.upsert_shared_thread(thread_id, user_id, title, data.clone())
+ .await
+ .unwrap();
+
+ let result = db.get_shared_thread(thread_id).await.unwrap();
+ assert!(result.is_some(), "Should find the shared thread");
+
+ let (thread, username) = result.unwrap();
+ assert_eq!(thread.title, title);
+ assert_eq!(thread.data, data);
+ assert_eq!(thread.user_id, user_id);
+ assert_eq!(username, "user1");
+}
+
+test_both_dbs!(
+ test_upsert_shared_thread_updates_existing,
+ test_upsert_shared_thread_updates_existing_postgres,
+ test_upsert_shared_thread_updates_existing_sqlite
+);
+
+async fn test_upsert_shared_thread_updates_existing(db: &Arc<Database>) {
+ use crate::db::SharedThreadId;
+ use uuid::Uuid;
+
+ let user_id = new_test_user(db, "user1@example.com").await;
+
+ let thread_id = SharedThreadId(Uuid::new_v4());
+
+ // Create initial thread.
+ db.upsert_shared_thread(
+ thread_id,
+ user_id,
+ "Original Title",
+ b"original data".to_vec(),
+ )
+ .await
+ .unwrap();
+
+ // Update the same thread.
+ db.upsert_shared_thread(
+ thread_id,
+ user_id,
+ "Updated Title",
+ b"updated data".to_vec(),
+ )
+ .await
+ .unwrap();
+
+ let result = db.get_shared_thread(thread_id).await.unwrap();
+ let (thread, _) = result.unwrap();
+
+ assert_eq!(thread.title, "Updated Title");
+ assert_eq!(thread.data, b"updated data".to_vec());
+}
+
+test_both_dbs!(
+ test_cannot_update_another_users_shared_thread,
+ test_cannot_update_another_users_shared_thread_postgres,
+ test_cannot_update_another_users_shared_thread_sqlite
+);
+
+async fn test_cannot_update_another_users_shared_thread(db: &Arc<Database>) {
+ use crate::db::SharedThreadId;
+ use uuid::Uuid;
+
+ let user1_id = new_test_user(db, "user1@example.com").await;
+ let user2_id = new_test_user(db, "user2@example.com").await;
+
+ let thread_id = SharedThreadId(Uuid::new_v4());
+
+ db.upsert_shared_thread(thread_id, user1_id, "User 1 Thread", b"user1 data".to_vec())
+ .await
+ .unwrap();
+
+ let result = db
+ .upsert_shared_thread(thread_id, user2_id, "User 2 Title", b"user2 data".to_vec())
+ .await;
+
+ assert!(
+ result.is_err(),
+ "Should not allow updating another user's thread"
+ );
+}
+
+test_both_dbs!(
+ test_get_nonexistent_shared_thread,
+ test_get_nonexistent_shared_thread_postgres,
+ test_get_nonexistent_shared_thread_sqlite
+);
+
+async fn test_get_nonexistent_shared_thread(db: &Arc<Database>) {
+ use crate::db::SharedThreadId;
+ use uuid::Uuid;
+
+ let result = db
+ .get_shared_thread(SharedThreadId(Uuid::new_v4()))
+ .await
+ .unwrap();
+
+ assert!(result.is_none(), "Should not find non-existent thread");
+}
@@ -6,7 +6,8 @@ use crate::{
db::{
self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
- RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
+ RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
+ UserId,
},
executor::Executor,
};
@@ -465,7 +466,9 @@ impl Server {
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
.add_message_handler(update_context)
.add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
- .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);
+ .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
+ .add_request_handler(share_agent_thread)
+ .add_request_handler(get_shared_agent_thread);
Arc::new(server)
}
@@ -4016,6 +4019,54 @@ fn project_left(project: &db::LeftProject, session: &Session) {
}
}
+async fn share_agent_thread(
+ request: proto::ShareAgentThread,
+ response: Response<proto::ShareAgentThread>,
+ session: MessageContext,
+) -> Result<()> {
+ let user_id = session.user_id();
+
+ let share_id = SharedThreadId::from_proto(request.session_id.clone())
+ .ok_or_else(|| anyhow!("Invalid session ID format"))?;
+
+ session
+ .db()
+ .await
+ .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
+ .await?;
+
+ response.send(proto::Ack {})?;
+
+ Ok(())
+}
+
+async fn get_shared_agent_thread(
+ request: proto::GetSharedAgentThread,
+ response: Response<proto::GetSharedAgentThread>,
+ session: MessageContext,
+) -> Result<()> {
+ let share_id = SharedThreadId::from_proto(request.session_id)
+ .ok_or_else(|| anyhow!("Invalid session ID format"))?;
+
+ let result = session.db().await.get_shared_thread(share_id).await?;
+
+ match result {
+ Some((thread, username)) => {
+ response.send(proto::GetSharedAgentThreadResponse {
+ title: thread.title,
+ thread_data: thread.data,
+ sharer_username: username,
+ created_at: thread.created_at.and_utc().to_rfc3339(),
+ })?;
+ }
+ None => {
+ return Err(anyhow!("Shared thread not found").into());
+ }
+ }
+
+ Ok(())
+}
+
pub trait ResultExt {
type Ok;
@@ -2,6 +2,7 @@ use call::Room;
use client::ChannelId;
use gpui::{Entity, TestAppContext};
+mod agent_sharing_tests;
mod channel_buffer_tests;
mod channel_guest_tests;
mod channel_tests;
@@ -0,0 +1,217 @@
+use agent::SharedThread;
+use gpui::{BackgroundExecutor, TestAppContext};
+use rpc::proto;
+use uuid::Uuid;
+
+use crate::tests::TestServer;
+
+#[gpui::test]
+async fn test_share_and_retrieve_thread(
+ executor: BackgroundExecutor,
+ cx_a: &mut TestAppContext,
+ cx_b: &mut TestAppContext,
+) {
+ let mut server = TestServer::start(executor.clone()).await;
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+
+ executor.run_until_parked();
+
+ let session_id = Uuid::new_v4().to_string();
+
+ let original_thread = SharedThread {
+ title: "Shared Test Thread".into(),
+ messages: vec![],
+ updated_at: chrono::Utc::now(),
+ model: None,
+ completion_mode: None,
+ version: SharedThread::VERSION.to_string(),
+ };
+
+ let thread_data = original_thread
+ .to_bytes()
+ .expect("Failed to serialize thread");
+
+ client_a
+ .client()
+ .request(proto::ShareAgentThread {
+ session_id: session_id.clone(),
+ title: original_thread.title.to_string(),
+ thread_data,
+ })
+ .await
+ .expect("Failed to share thread");
+
+ let get_response = client_b
+ .client()
+ .request(proto::GetSharedAgentThread {
+ session_id: session_id.clone(),
+ })
+ .await
+ .expect("Failed to get shared thread");
+
+ let imported_shared_thread =
+ SharedThread::from_bytes(&get_response.thread_data).expect("Failed to deserialize thread");
+
+ assert_eq!(imported_shared_thread.title, original_thread.title);
+ assert_eq!(imported_shared_thread.version, SharedThread::VERSION);
+
+ let db_thread = imported_shared_thread.to_db_thread();
+
+ assert!(
+ db_thread.title.starts_with("🔗"),
+ "Imported thread title should have link prefix"
+ );
+ assert!(
+ db_thread.title.contains("Shared Test Thread"),
+ "Imported thread should preserve original title"
+ );
+}
+
+#[gpui::test]
+async fn test_reshare_updates_existing_thread(
+ executor: BackgroundExecutor,
+ cx_a: &mut TestAppContext,
+ cx_b: &mut TestAppContext,
+) {
+ let mut server = TestServer::start(executor.clone()).await;
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+
+ executor.run_until_parked();
+
+ let session_id = Uuid::new_v4().to_string();
+
+ client_a
+ .client()
+ .request(proto::ShareAgentThread {
+ session_id: session_id.clone(),
+ title: "Original Title".to_string(),
+ thread_data: b"original data".to_vec(),
+ })
+ .await
+ .expect("Failed to share thread");
+
+ client_a
+ .client()
+ .request(proto::ShareAgentThread {
+ session_id: session_id.clone(),
+ title: "Updated Title".to_string(),
+ thread_data: b"updated data".to_vec(),
+ })
+ .await
+ .expect("Failed to re-share thread");
+
+ let get_response = client_b
+ .client()
+ .request(proto::GetSharedAgentThread {
+ session_id: session_id.clone(),
+ })
+ .await
+ .expect("Failed to get shared thread");
+
+ assert_eq!(get_response.title, "Updated Title");
+ assert_eq!(get_response.thread_data, b"updated data".to_vec());
+}
+
+#[gpui::test]
+async fn test_get_nonexistent_thread(executor: BackgroundExecutor, cx: &mut TestAppContext) {
+ let mut server = TestServer::start(executor.clone()).await;
+ let client = server.create_client(cx, "user_a").await;
+
+ executor.run_until_parked();
+
+ let nonexistent_session_id = Uuid::new_v4().to_string();
+
+ let result = client
+ .client()
+ .request(proto::GetSharedAgentThread {
+ session_id: nonexistent_session_id,
+ })
+ .await;
+
+ assert!(result.is_err(), "Should fail for nonexistent thread");
+}
+
+#[gpui::test]
+async fn test_sync_imported_thread(
+ executor: BackgroundExecutor,
+ cx_a: &mut TestAppContext,
+ cx_b: &mut TestAppContext,
+) {
+ let mut server = TestServer::start(executor.clone()).await;
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+
+ executor.run_until_parked();
+
+ let session_id = Uuid::new_v4().to_string();
+
+ // User A shares a thread with initial content.
+ let initial_thread = SharedThread {
+ title: "Initial Title".into(),
+ messages: vec![],
+ updated_at: chrono::Utc::now(),
+ model: None,
+ completion_mode: None,
+ version: SharedThread::VERSION.to_string(),
+ };
+
+ client_a
+ .client()
+ .request(proto::ShareAgentThread {
+ session_id: session_id.clone(),
+ title: initial_thread.title.to_string(),
+ thread_data: initial_thread.to_bytes().expect("Failed to serialize"),
+ })
+ .await
+ .expect("Failed to share thread");
+
+ // User B imports the thread.
+ let initial_response = client_b
+ .client()
+ .request(proto::GetSharedAgentThread {
+ session_id: session_id.clone(),
+ })
+ .await
+ .expect("Failed to get shared thread");
+
+ let initial_imported =
+ SharedThread::from_bytes(&initial_response.thread_data).expect("Failed to deserialize");
+ assert_eq!(initial_imported.title.as_ref(), "Initial Title");
+
+ // User A updates the shared thread.
+ let updated_thread = SharedThread {
+ title: "Updated Title".into(),
+ messages: vec![],
+ updated_at: chrono::Utc::now(),
+ model: None,
+ completion_mode: None,
+ version: SharedThread::VERSION.to_string(),
+ };
+
+ client_a
+ .client()
+ .request(proto::ShareAgentThread {
+ session_id: session_id.clone(),
+ title: updated_thread.title.to_string(),
+ thread_data: updated_thread.to_bytes().expect("Failed to serialize"),
+ })
+ .await
+ .expect("Failed to re-share thread");
+
+ // User B syncs the imported thread (fetches the latest version).
+ let synced_response = client_b
+ .client()
+ .request(proto::GetSharedAgentThread {
+ session_id: session_id.clone(),
+ })
+ .await
+ .expect("Failed to sync shared thread");
+
+ let synced_thread =
+ SharedThread::from_bytes(&synced_response.thread_data).expect("Failed to deserialize");
+
+ // The synced thread should have the updated title.
+ assert_eq!(synced_thread.title.as_ref(), "Updated Title");
+}
@@ -29,3 +29,9 @@ pub struct AcpBetaFeatureFlag;
impl FeatureFlag for AcpBetaFeatureFlag {
const NAME: &'static str = "acp-beta";
}
+
+pub struct AgentSharingFeatureFlag;
+
+impl FeatureFlag for AgentSharingFeatureFlag {
+ const NAME: &'static str = "agent-sharing";
+}
@@ -218,3 +218,20 @@ message NewExternalAgentVersionAvailable {
string name = 2;
string version = 3;
}
+
+message ShareAgentThread {
+ string session_id = 1; // Client-generated UUID (acp::SessionId)
+ string title = 2;
+ bytes thread_data = 3;
+}
+
+message GetSharedAgentThread {
+ string session_id = 1; // UUID string
+}
+
+message GetSharedAgentThreadResponse {
+ string title = 1;
+ bytes thread_data = 2;
+ string sharer_username = 3;
+ string created_at = 4;
+}
@@ -447,7 +447,11 @@ message Envelope {
GitRemoveRemote git_remove_remote = 403;
TrustWorktrees trust_worktrees = 404;
- RestrictWorktrees restrict_worktrees = 405; // current max
+ RestrictWorktrees restrict_worktrees = 405;
+
+ ShareAgentThread share_agent_thread = 406;
+ GetSharedAgentThread get_shared_agent_thread = 407;
+ GetSharedAgentThreadResponse get_shared_agent_thread_response = 408; // current max
}
reserved 87 to 88;
@@ -342,7 +342,10 @@ messages!(
(RemoteStarted, Background),
(GitGetWorktrees, Background),
(GitWorktreesResponse, Background),
- (GitCreateWorktree, Background)
+ (GitCreateWorktree, Background),
+ (ShareAgentThread, Foreground),
+ (GetSharedAgentThread, Foreground),
+ (GetSharedAgentThreadResponse, Foreground)
);
request_messages!(
@@ -441,6 +444,8 @@ request_messages!(
(SendChannelMessage, SendChannelMessageResponse),
(SetChannelMemberRole, Ack),
(SetChannelVisibility, Ack),
+ (ShareAgentThread, Ack),
+ (GetSharedAgentThread, GetSharedAgentThreadResponse),
(ShareProject, ShareProjectResponse),
(SynchronizeBuffers, SynchronizeBuffersResponse),
(TaskContextForLocation, TaskContext),
@@ -50,6 +50,8 @@ required-features = ["visual-tests"]
[dependencies]
acp_tools.workspace = true
activity_indicator.workspace = true
+agent.workspace = true
+agent-client-protocol.workspace = true
agent_settings.workspace = true
agent_ui.workspace = true
agent_ui_v2.workspace = true
@@ -4,6 +4,8 @@
mod reliability;
mod zed;
+use agent::{HistoryStore, SharedThread};
+use agent_client_protocol;
use agent_ui::AgentPanel;
use anyhow::{Context as _, Error, Result};
use clap::Parser;
@@ -33,6 +35,7 @@ use assets::Assets;
use node_runtime::{NodeBinaryOptions, NodeRuntime};
use parking_lot::Mutex;
use project::{project_settings::ProjectSettings, trusted_worktrees};
+use proto;
use recent_projects::{RemoteSettings, open_remote_project};
use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
use session::{AppSession, Session};
@@ -837,6 +840,73 @@ fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut
})
.detach_and_log_err(cx);
}
+ OpenRequestKind::SharedAgentThread { session_id } => {
+ cx.spawn(async move |cx| {
+ let workspace =
+ workspace::get_any_active_workspace(app_state.clone(), cx.clone()).await?;
+
+ let (client, history_store) =
+ workspace.update(cx, |workspace, _window, cx| {
+ let client = workspace.project().read(cx).client();
+ let history_store: Option<gpui::Entity<HistoryStore>> = workspace
+ .panel::<AgentPanel>(cx)
+ .map(|panel| panel.read(cx).thread_store().clone());
+ (client, history_store)
+ })?;
+
+ let Some(history_store): Option<gpui::Entity<HistoryStore>> = history_store
+ else {
+ anyhow::bail!("Agent panel not available");
+ };
+
+ let response = client
+ .request(proto::GetSharedAgentThread {
+ session_id: session_id.clone(),
+ })
+ .await
+ .context("Failed to fetch shared thread")?;
+
+ let shared_thread = SharedThread::from_bytes(&response.thread_data)?;
+ let db_thread = shared_thread.to_db_thread();
+ let session_id = agent_client_protocol::SessionId::new(session_id);
+
+ history_store
+ .update(&mut cx.clone(), |store, cx| {
+ store.save_thread(session_id.clone(), db_thread, cx)
+ })?
+ .await?;
+
+ let thread_metadata = agent::DbThreadMetadata {
+ id: session_id,
+ title: format!("🔗 {}", response.title).into(),
+ updated_at: chrono::Utc::now(),
+ };
+
+ workspace.update(cx, |workspace, window, cx| {
+ if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
+ panel.update(cx, |panel, cx| {
+ panel.open_thread(thread_metadata, window, cx);
+ });
+ panel.focus_handle(cx).focus(window, cx);
+ }
+ })?;
+
+ workspace.update(cx, |workspace, _window, cx| {
+ struct ImportedThreadToast;
+ workspace.show_toast(
+ Toast::new(
+ NotificationId::unique::<ImportedThreadToast>(),
+ format!("Imported shared thread from {}", response.sharer_username),
+ )
+ .autohide(),
+ cx,
+ );
+ })?;
+
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+ }
OpenRequestKind::DockMenuAction { index } => {
cx.perform_dock_menu_action(index);
}
@@ -4,6 +4,7 @@ pub mod edit_prediction_registry;
pub(crate) mod mac_only_instance;
mod migrate;
mod open_listener;
+mod open_url_modal;
mod quick_action_bar;
#[cfg(all(target_os = "macos", any(test, feature = "test-support")))]
pub mod visual_tests;
@@ -141,6 +142,8 @@ actions!(
/// audio system (including yourself) on the current call in a tar file
/// in the current working directory.
CaptureRecentAudio,
+ /// Opens a prompt to enter a URL to open.
+ OpenUrlPrompt,
]
);
@@ -823,6 +826,11 @@ fn register_actions(
..Default::default()
})
})
+ .register_action(|workspace, _: &OpenUrlPrompt, window, cx| {
+ workspace.toggle_modal(window, cx, |window, cx| {
+ open_url_modal::OpenUrlModal::new(window, cx)
+ });
+ })
.register_action(|workspace, action: &OpenBrowser, _window, cx| {
// Parse and validate the URL to ensure it's properly formatted
match url::Url::parse(&action.url) {
@@ -49,6 +49,9 @@ pub enum OpenRequestKind {
extension_id: String,
},
AgentPanel,
+ SharedAgentThread {
+ session_id: String,
+ },
DockMenuAction {
index: usize,
},
@@ -107,6 +110,14 @@ impl OpenRequest {
});
} else if url == "zed://agent" {
this.kind = Some(OpenRequestKind::AgentPanel);
+ } else if let Some(session_id_str) = url.strip_prefix("zed://agent/shared/") {
+ if uuid::Uuid::parse_str(session_id_str).is_ok() {
+ this.kind = Some(OpenRequestKind::SharedAgentThread {
+ session_id: session_id_str.to_string(),
+ });
+ } else {
+ log::error!("Invalid session ID in URL: {}", session_id_str);
+ }
} else if let Some(schema_path) = url.strip_prefix("zed://schemas/") {
this.kind = Some(OpenRequestKind::BuiltinJsonSchema {
schema_path: schema_path.to_string(),
@@ -0,0 +1,116 @@
+use editor::Editor;
+use gpui::{AppContext as _, DismissEvent, Entity, EventEmitter, Focusable, ReadGlobal, Styled};
+use ui::{
+ ActiveTheme, App, Color, Context, FluentBuilder, InteractiveElement, IntoElement, Label,
+ LabelCommon, LabelSize, ParentElement, Render, SharedString, StyledExt, Window, div, h_flex,
+ v_flex,
+};
+use workspace::ModalView;
+
+use super::{OpenListener, RawOpenRequest};
+
+pub struct OpenUrlModal {
+ editor: Entity<Editor>,
+ last_error: Option<SharedString>,
+}
+
+impl EventEmitter<DismissEvent> for OpenUrlModal {}
+impl ModalView for OpenUrlModal {}
+
+impl Focusable for OpenUrlModal {
+ fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
+ self.editor.focus_handle(cx)
+ }
+}
+
+impl OpenUrlModal {
+ pub fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
+ let editor = cx.new(|cx| {
+ let mut editor = Editor::single_line(window, cx);
+ editor.set_placeholder_text("zed://...", window, cx);
+ editor
+ });
+
+ Self {
+ editor,
+ last_error: None,
+ }
+ }
+
+ fn cancel(&mut self, _: &menu::Cancel, _window: &mut Window, cx: &mut Context<Self>) {
+ cx.emit(DismissEvent);
+ }
+
+ fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
+ let url = self.editor.update(cx, |editor, cx| {
+ let text = editor.text(cx).trim().to_string();
+ editor.clear(window, cx);
+ text
+ });
+
+ if url.is_empty() {
+ cx.emit(DismissEvent);
+ return;
+ }
+
+ // Handle zed:// URLs internally.
+ if url.starts_with("zed://") || url.starts_with("zed-cli://") {
+ OpenListener::global(cx).open(RawOpenRequest {
+ urls: vec![url],
+ ..Default::default()
+ });
+ cx.emit(DismissEvent);
+ return;
+ }
+
+ match url::Url::parse(&url) {
+ Ok(parsed_url) => {
+ cx.open_url(parsed_url.as_str());
+ cx.emit(DismissEvent);
+ }
+ Err(e) => {
+ self.last_error = Some(format!("Invalid URL: {}", e).into());
+ cx.notify();
+ }
+ }
+ }
+}
+
+impl Render for OpenUrlModal {
+ fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let theme = cx.theme();
+
+ v_flex()
+ .key_context("OpenUrlModal")
+ .on_action(cx.listener(Self::cancel))
+ .on_action(cx.listener(Self::confirm))
+ .elevation_3(cx)
+ .w_96()
+ .overflow_hidden()
+ .child(
+ div()
+ .p_2()
+ .border_b_1()
+ .border_color(theme.colors().border_variant)
+ .child(self.editor.clone()),
+ )
+ .child(
+ h_flex()
+ .bg(theme.colors().editor_background)
+ .rounded_b_sm()
+ .w_full()
+ .p_2()
+ .gap_1()
+ .when_some(self.last_error.clone(), |this, error| {
+ this.child(Label::new(error).size(LabelSize::Small).color(Color::Error))
+ })
+ .when(self.last_error.is_none(), |this| {
+ this.child(
+ Label::new("Paste a URL to open.")
+ .color(Color::Muted)
+ .size(LabelSize::Small),
+ )
+ }),
+ )
+ }
+}