use std::path::PathBuf;
use std::sync::Arc;

use anyhow::{anyhow, Result};
use assistant_tool::{ToolId, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
use context_server::manager::ContextServerManager;
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
use futures::future::{self, BoxFuture, Shared};
use futures::FutureExt as _;
use gpui::{
    prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
};
use heed::types::{SerdeBincode, SerdeJson};
use heed::Database;
use language_model::{LanguageModelToolUseId, Role};
use project::Project;
use prompt_store::PromptBuilder;
use serde::{Deserialize, Serialize};
use util::ResultExt as _;

use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadId};

/// Initializes this module's global state by installing the shared
/// [`ThreadsDatabase`] future on the application.
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}

/// Creates, opens, saves and deletes [`Thread`]s, and keeps an in-memory
/// listing of the metadata of the threads stored in the threads database.
pub struct ThreadStore {
    /// Project handle passed to every thread this store creates or opens.
    project: Entity<Project>,
    /// Tool working set shared with threads; context server tools are
    /// inserted into and removed from it as servers start and stop.
    tools: Arc<ToolWorkingSet>,
    /// Prompt builder passed to every thread this store creates or opens.
    prompt_builder: Arc<PromptBuilder>,
    /// Manager whose start/stop events this store subscribes to.
    context_server_manager: Entity<ContextServerManager>,
    /// IDs of the tools registered for each context server, keyed by server ID,
    /// so they can be removed again when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Metadata of the stored threads, as of the last completed `reload`
    /// (minus any entries removed by `delete_thread` since).
    threads: Vec<SerializedThreadMetadata>,
}

impl ThreadStore {
    /// Builds a new store entity for `project`.
    ///
    /// A [`ContextServerManager`] is created for the project, the store
    /// subscribes to its events, and an initial [`Self::reload`] is started in
    /// the background (its error, if any, is logged rather than returned).
    ///
    /// The current implementation always returns `Ok`.
    pub fn new(
        project: Entity<Project>,
        tools: Arc<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut App,
    ) -> Result<Entity<Self>> {
        Ok(cx.new(|cx| {
            let registry = ContextServerFactoryRegistry::default_global(cx);
            let context_server_manager =
                cx.new(|cx| ContextServerManager::new(registry, project.clone(), cx));

            let store = Self {
                project,
                tools,
                prompt_builder,
                context_server_manager,
                context_server_tool_ids: HashMap::default(),
                threads: Vec::new(),
            };

            // Subscribe before kicking off the first listing so the store is
            // fully wired up by the time it is returned.
            store.register_context_server_handlers(cx);
            store.reload(cx).detach_and_log_err(cx);

            store
        }))
    }

    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
        self.context_server_manager.clone()
    }

    /// Returns a shared handle to the tool working set used by this store.
    pub fn tools(&self) -> Arc<ToolWorkingSet> {
        Arc::clone(&self.tools)
    }

    /// Returns the number of threads in the in-memory metadata listing.
    ///
    /// This reflects the state as of the last completed reload; it does not
    /// query the database.
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }

    /// Returns a copy of the cached thread metadata, most recently updated
    /// first.
    pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
        let mut sorted = self.threads.clone();
        // Comparing `b` against `a` yields descending `updated_at` order.
        sorted.sort_unstable_by(|a, b| b.updated_at.cmp(&a.updated_at));
        sorted
    }

    /// Returns at most `limit` thread metadata entries, most recently updated
    /// first.
    pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
        let mut recent = self.threads();
        recent.truncate(limit);
        recent
    }

    /// Creates a new, empty [`Thread`] entity that shares this store's
    /// project, tools and prompt builder.
    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
        let project = self.project.clone();
        let tools = self.tools.clone();
        let prompt_builder = self.prompt_builder.clone();

        cx.new(|cx| Thread::new(project, tools, prompt_builder, cx))
    }

    /// Loads the thread with the given `id` from the database and
    /// materializes it as a [`Thread`] entity.
    ///
    /// # Errors
    ///
    /// The returned task fails if the database could not be opened, if the
    /// lookup fails, if no thread is stored under `id`, or if updating the
    /// store afterwards fails.
    pub fn open_thread(
        &self,
        id: &ThreadId,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let thread_id = id.clone();
        let database = ThreadsDatabase::global_future(cx);

        cx.spawn(async move |store, cx| {
            let database = database.await.map_err(|err| anyhow!(err))?;
            let lookup = database.try_find_thread(thread_id.clone()).await?;
            let serialized =
                lookup.ok_or_else(|| anyhow!("no thread found with ID: {thread_id:?}"))?;

            store.update(cx, |store, cx| {
                let project = store.project.clone();
                let tools = store.tools.clone();
                let prompt_builder = store.prompt_builder.clone();

                cx.new(|cx| {
                    Thread::deserialize(
                        thread_id.clone(),
                        serialized,
                        project,
                        tools,
                        prompt_builder,
                        cx,
                    )
                })
            })
        })
    }

    /// Serializes `thread`, writes it to the database under its ID, and then
    /// reloads the cached metadata listing.
    ///
    /// # Errors
    ///
    /// The returned task fails if serialization fails, if the database could
    /// not be opened, if the write fails, or if the subsequent reload fails.
    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
        // Grab the ID and start serialization in a single update of the thread.
        let (thread_id, serialize_task) =
            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
        let database = ThreadsDatabase::global_future(cx);

        cx.spawn(async move |store, cx| {
            let serialized = serialize_task.await?;
            let database = database.await.map_err(|err| anyhow!(err))?;
            database.save_thread(thread_id, serialized).await?;

            let reload = store.update(cx, |store, cx| store.reload(cx))?;
            reload.await
        })
    }

    /// Deletes the thread with the given `id` from the database and removes
    /// it from the cached metadata listing.
    ///
    /// Observers of the store are notified once the listing has changed, in
    /// the same way as after [`Self::reload`].
    ///
    /// # Errors
    ///
    /// The returned task fails if the database could not be opened, if the
    /// delete fails, or if updating the store afterwards fails.
    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            database.delete_thread(id.clone()).await?;

            this.update(cx, |this, cx| {
                this.threads.retain(|thread| thread.id != id);
                // The listing changed, so observers must be told; without
                // this they would keep showing the deleted thread.
                cx.notify();
            })
        })
    }

    /// Re-reads the metadata of all stored threads from the database,
    /// replaces the cached listing with it, and notifies observers.
    ///
    /// # Errors
    ///
    /// The returned task fails if the database could not be opened, if
    /// listing fails, or if updating the store afterwards fails.
    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let database = ThreadsDatabase::global_future(cx);

        cx.spawn(async move |store, cx| {
            let database = database.await.map_err(|err| anyhow!(err))?;
            let listed = database.list_threads().await?;

            store.update(cx, |store, cx| {
                store.threads = listed;
                cx.notify();
            })
        })
    }

    /// Subscribes this store to events from its context server manager,
    /// routing them to [`Self::handle_context_server_event`].
    ///
    /// The subscription is detached rather than stored on the store.
    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
        let manager = &self.context_server_manager;
        cx.subscribe(manager, Self::handle_context_server_event)
            .detach();
    }

    /// Reacts to context server lifecycle events.
    ///
    /// * `ServerStarted`: if the server can be found and reports the `Tools`
    ///   capability, its tools are listed asynchronously, each one is inserted
    ///   into the shared tool working set, and the resulting tool IDs are
    ///   recorded under the server's ID.
    /// * `ServerStopped`: the tool IDs recorded for the server (if any) are
    ///   removed from the tool working set.
    fn handle_context_server_event(
        &mut self,
        context_server_manager: Entity<ContextServerManager>,
        event: &context_server::manager::Event,
        cx: &mut Context<Self>,
    ) {
        let tool_working_set = self.tools.clone();
        match event {
            context_server::manager::Event::ServerStarted { server_id } => {
                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
                    let context_server_manager = context_server_manager.clone();
                    // Tool discovery awaits the server, so it runs in a
                    // detached task instead of blocking the event handler.
                    cx.spawn({
                        let server = server.clone();
                        let server_id = server_id.clone();
                        async move |this, cx| {
                            // Nothing to register if the server has no client.
                            let Some(protocol) = server.client() else {
                                return;
                            };

                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
                                // A failed listing is passed to `log_err` and
                                // registration is skipped for this server.
                                if let Some(tools) = protocol.list_tools().await.log_err() {
                                    let tool_ids = tools
                                        .tools
                                        .into_iter()
                                        .map(|tool| {
                                            log::info!(
                                                "registering context server tool: {:?}",
                                                tool.name
                                            );
                                            tool_working_set.insert(Arc::new(
                                                ContextServerTool::new(
                                                    context_server_manager.clone(),
                                                    server.id(),
                                                    tool,
                                                ),
                                            ))
                                        })
                                        .collect::<Vec<_>>();

                                    // Remember which tools belong to this server
                                    // so `ServerStopped` can remove them.
                                    // NOTE(review): an existing entry for the same
                                    // server ID is replaced without removing its
                                    // tools from the working set — confirm a start
                                    // is always preceded by a stop.
                                    this.update(cx, |this, _cx| {
                                        this.context_server_tool_ids.insert(server_id, tool_ids);
                                    })
                                    .log_err();
                                }
                            }
                        }
                    })
                    .detach();
                }
            }
            context_server::manager::Event::ServerStopped { server_id } => {
                // Only tools recorded for this server are removed; if none
                // were recorded, there is nothing to do.
                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
                    tool_working_set.remove(&tool_ids);
                }
            }
        }
    }
}

/// Summary information about a stored thread, built from its database key
/// and the `summary` / `updated_at` fields of its [`SerializedThread`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Key under which the thread is stored in the database.
    pub id: ThreadId,
    /// The stored thread's summary text.
    pub summary: SharedString,
    /// The stored thread's last-updated timestamp; listings are sorted on it.
    pub updated_at: DateTime<Utc>,
}

/// The persisted form of a thread; this is the value type of the threads
/// database, where it is encoded as JSON.
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
    /// Summary text of the thread.
    pub summary: SharedString,
    /// Last-updated timestamp of the thread.
    pub updated_at: DateTime<Utc>,
    /// The thread's messages.
    pub messages: Vec<SerializedMessage>,
    /// Optional project snapshot; `None` when the field is absent from the
    /// stored data (via `serde(default)`).
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}

/// The persisted form of a single message within a [`SerializedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    /// Identifier of the message.
    pub id: MessageId,
    /// Role associated with the message.
    pub role: Role,
    /// Text content of the message.
    pub text: String,
    /// Tool uses attached to the message; empty when absent from stored data.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results attached to the message; empty when absent from stored data.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}

/// The persisted form of a tool use recorded on a [`SerializedMessage`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    /// Identifier of this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool.
    pub name: SharedString,
    /// Input for the tool, stored as an arbitrary JSON value.
    pub input: serde_json::Value,
}

/// The persisted form of a tool result recorded on a [`SerializedMessage`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// Identifier of the tool use this result refers to.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the result is flagged as an error.
    pub is_error: bool,
    /// Content of the result.
    pub content: Arc<str>,
}

/// Application global holding the shared future that resolves to the opened
/// [`ThreadsDatabase`], or to the error produced while opening it.
///
/// The future is `Shared`, so every caller awaits the same open attempt and
/// receives a clone of the same `Arc`-wrapped outcome.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}

/// heed-backed storage for serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database operation is spawned.
    executor: BackgroundExecutor,
    /// The heed environment the database lives in.
    env: heed::Env,
    /// The `"threads"` database: keys are bincode-encoded [`ThreadId`]s,
    /// values are JSON-encoded [`SerializedThread`]s.
    threads: Database<SerdeBincode<ThreadId>, SerdeJson<SerializedThread>>,
}

impl ThreadsDatabase {
    fn global_future(
        cx: &mut App,
    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
        GlobalThreadsDatabase::global(cx).0.clone()
    }

    /// Starts opening the threads database on the background executor and
    /// installs the resulting shared future as the [`GlobalThreadsDatabase`]
    /// global.
    ///
    /// Both the opened database and any open error are wrapped in `Arc`, so
    /// the outcome can be cloned out of the shared future.
    fn init(cx: &mut App) {
        let executor = cx.background_executor().clone();
        let database_path = paths::support_dir().join("threads/threads-db.1.mdb");

        let open_database = {
            let executor = executor.clone();
            async move { ThreadsDatabase::new(database_path, executor) }
        };

        let database_future = executor
            .spawn(open_database)
            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
            .boxed()
            .shared();

        cx.set_global(GlobalThreadsDatabase(database_future));
    }

    /// Opens (creating it if necessary) the heed environment at `path` and
    /// the `"threads"` database inside it.
    ///
    /// `path` is treated as a directory and is created, along with any missing
    /// parents, before the environment is opened. The environment is opened
    /// with a 1 GiB map size and room for a single named database.
    ///
    /// # Errors
    ///
    /// Fails if the directory cannot be created, if the environment cannot be
    /// opened, or if creating the database or committing the setup
    /// transaction fails.
    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
        const MAP_SIZE_BYTES: usize = 1 << 30;

        std::fs::create_dir_all(&path)?;

        // NOTE(review): heed marks `open` as unsafe; confirm that its documented
        // requirements hold for this path (presumably that the environment is
        // not opened twice in this process and is not modified externally).
        let env = unsafe {
            heed::EnvOpenOptions::new()
                .map_size(MAP_SIZE_BYTES)
                .max_dbs(1)
                .open(path)?
        };

        // Creating the named database needs a write transaction of its own.
        let threads = {
            let mut setup_txn = env.write_txn()?;
            let threads = env.create_database(&mut setup_txn, Some("threads"))?;
            setup_txn.commit()?;
            threads
        };

        Ok(Self {
            executor,
            env,
            threads,
        })
    }

    /// Lists the metadata of every stored thread, in the database's iteration
    /// order.
    ///
    /// The work is spawned on the database's executor inside a read
    /// transaction. Each stored thread is decoded in full and only its
    /// `summary` and `updated_at` are kept.
    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
        let env = self.env.clone();
        let database = self.threads;

        self.executor.spawn(async move {
            let txn = env.read_txn()?;
            let mut listing = Vec::new();
            for entry in database.iter(&txn)? {
                let (id, thread) = entry?;
                listing.push(SerializedThreadMetadata {
                    id,
                    summary: thread.summary,
                    updated_at: thread.updated_at,
                });
            }

            Ok(listing)
        })
    }

    /// Looks up the thread stored under `id`.
    ///
    /// The returned task resolves to `Ok(None)` when no thread is stored under
    /// that ID, and to an error if the read transaction or the lookup fails.
    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
        let env = self.env.clone();
        let database = self.threads;

        self.executor.spawn(async move {
            let txn = env.read_txn()?;
            let found = database.get(&txn, &id)?;
            Ok(found)
        })
    }

    /// Stores `thread` under `id`, within a single committed write
    /// transaction spawned on the database's executor.
    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
        let env = self.env.clone();
        let database = self.threads;

        self.executor.spawn(async move {
            let mut write_txn = env.write_txn()?;
            database.put(&mut write_txn, &id, &thread)?;
            write_txn.commit()?;
            Ok(())
        })
    }

    /// Deletes the entry stored under `id`, within a single committed write
    /// transaction spawned on the database's executor.
    ///
    /// The value returned by the underlying delete call is not inspected, so a
    /// missing entry is not reported as an error by this method.
    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
        let env = self.env.clone();
        let database = self.threads;

        self.executor.spawn(async move {
            let mut write_txn = env.write_txn()?;
            database.delete(&mut write_txn, &id)?;
            write_txn.commit()?;
            Ok(())
        })
    }
}
