use std::str::FromStr;
use std::sync::Arc;

use anyhow::Context as _;
use serde::{Deserialize, Serialize};
use strum::{Display, EnumIter, EnumString};
use uuid::Uuid;

/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";

/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";

/// The name of the header used to indicate what plan the user is currently on.
pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";

/// The name of the header used to indicate the usage limit for model requests.
pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";

/// The name of the header used to indicate the usage amount for model requests.
pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";

/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";

/// The name of the header used to indicate the resource for which the subscription limit has been reached.
///
/// See [`MODEL_REQUESTS_RESOURCE_HEADER_VALUE`] and [`EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE`]
/// for the values defined in this module.
pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";

/// The [`SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`] value identifying model requests.
pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";

/// The [`SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`] value identifying edit predictions.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";

/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";

/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";

/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";

/// A usage limit: either a finite count or no limit at all.
///
/// Variant names are serialized in `snake_case`, so in JSON `Limited(50)` is
/// `{"limited": 50}` and `Unlimited` is `"unlimited"`. Plain text such as `"unlimited"`
/// or `"50"` can be parsed via this type's `FromStr` implementation.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UsageLimit {
    /// A finite limit with the given count.
    Limited(i32),
    /// No limit.
    Unlimited,
}

impl FromStr for UsageLimit {
    type Err = anyhow::Error;

    fn from_str(value: &str) -> Result<Self, Self::Err> {
        match value {
            "unlimited" => Ok(Self::Unlimited),
            limit => limit
                .parse::<i32>()
                .map(Self::Limited)
                .context("failed to parse limit"),
        }
    }
}

/// The subscription plan a user is on.
///
/// Serialized with `snake_case` names (`zed_free`, `zed_pro`, `zed_pro_trial`), the same
/// spellings returned by [`Plan::as_str`]. Deserialization additionally accepts the
/// PascalCase aliases `Free`, `ZedPro`, and `ZedProTrial`. The default is [`Plan::ZedFree`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Plan {
    /// The free plan.
    #[default]
    #[serde(alias = "Free")]
    ZedFree,
    /// The Pro plan.
    #[serde(alias = "ZedPro")]
    ZedPro,
    /// A trial of the Pro plan.
    #[serde(alias = "ZedProTrial")]
    ZedProTrial,
}

impl Plan {
    /// Returns this plan's `snake_case` identifier (e.g. `"zed_pro"`).
    ///
    /// These are exactly the spellings accepted by the `FromStr` implementation.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::ZedFree => "zed_free",
            Self::ZedPro => "zed_pro",
            Self::ZedProTrial => "zed_pro_trial",
        }
    }

    /// Returns the model-request limit for this plan: 50 (free), 150 (trial), or 500 (Pro).
    pub fn model_requests_limit(&self) -> UsageLimit {
        let allowance = match *self {
            Self::ZedFree => 50,
            Self::ZedProTrial => 150,
            Self::ZedPro => 500,
        };
        UsageLimit::Limited(allowance)
    }

    /// Returns the edit-prediction limit for this plan: 2,000 on the free plan and
    /// unlimited on both the Pro plan and the Pro trial.
    pub fn edit_predictions_limit(&self) -> UsageLimit {
        match *self {
            Self::ZedFree => UsageLimit::Limited(2_000),
            Self::ZedPro | Self::ZedProTrial => UsageLimit::Unlimited,
        }
    }
}

impl FromStr for Plan {
    type Err = anyhow::Error;

    /// Parses a plan from its `snake_case` identifier (the form produced by [`Plan::as_str`]).
    ///
    /// Unlike serde deserialization, the PascalCase aliases are not accepted here.
    ///
    /// # Errors
    ///
    /// Returns an "invalid plan" error naming the offending input for any other string.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        let parsed = match value {
            "zed_free" => Self::ZedFree,
            "zed_pro" => Self::ZedPro,
            "zed_pro_trial" => Self::ZedProTrial,
            unknown => anyhow::bail!("invalid plan: {unknown:?}"),
        };
        Ok(parsed)
    }
}

/// A language model provider.
///
/// serde and `strum` are both configured with `snake_case` naming (`anthropic`, `open_ai`,
/// `google`), so the serialized form matches what `Display` writes and what `FromStr`
/// (derived via `EnumString`) accepts. `EnumIter` allows iterating over all providers.
#[derive(
    Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
}

/// The request body for an edit prediction.
///
/// Each `Option` field is omitted from the serialized form when `None` and defaults to
/// `None` when absent from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsBody {
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub outline: Option<String>,
    pub input_events: String,
    pub input_excerpt: String,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub speculated_output: Option<String>,
    /// Whether the user provided consent for sampling this interaction.
    ///
    /// Defaults to `false` when absent. When deserializing, the field may also be
    /// supplied under the name `data_collection_permission`.
    #[serde(default, alias = "data_collection_permission")]
    pub can_collect_data: bool,
    /// Diagnostic groups as `(String, JSON value)` pairs.
    // NOTE(review): the meaning of the string key is not defined here — confirm with the producer.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
}

/// The response to an edit prediction request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
    /// An identifier for the prediction request.
    pub request_id: Uuid,
    pub output_excerpt: String,
}

/// The request body for accepting an edit prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcceptEditPredictionBody {
    // NOTE(review): presumably echoes `PredictEditsResponse::request_id` — confirm.
    pub request_id: Uuid,
}

/// The mode in which a completion is requested, sent as [`CompletionBody::mode`].
///
/// NOTE(review): `Max` presumably relates to [`LanguageModel::supports_max_mode`] and
/// [`LanguageModel::max_token_count_in_max_mode`] — confirm.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionMode {
    // Variant order is significant: the derived `Ord` makes `Normal < Max`.
    Normal,
    Max,
}

/// The reason a completion was requested, sent as [`CompletionBody::intent`].
///
/// Serialized with `snake_case` variant names (e.g. `user_prompt`, `tool_results`).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionIntent {
    UserPrompt,
    ToolResults,
    ThreadSummarization,
    ThreadContextSummarization,
    CreateFile,
    EditFile,
    InlineAssist,
    TerminalInlineAssist,
    GenerateGitCommitMessage,
}

/// The request body for a completion.
///
/// Each `Option` field is omitted from the serialized form when `None` and defaults to
/// `None` when absent from the input.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionBody {
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub thread_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub prompt_id: Option<String>,
    /// Why the completion was requested.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub intent: Option<CompletionIntent>,
    /// The completion mode to use.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub mode: Option<CompletionMode>,
    /// The provider the request is intended for.
    pub provider: LanguageModelProvider,
    /// The model identifier.
    pub model: String,
    /// The request payload for `provider`, kept as untyped JSON.
    // NOTE(review): presumably in the provider's native request format — confirm.
    pub provider_request: serde_json::Value,
}

/// A status update about a completion request.
///
/// Serialized with `snake_case` variant names (e.g. `queued`, `usage_updated`).
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionRequestStatus {
    /// The request is queued.
    // NOTE(review): whether `position` is 0- or 1-based is not defined here — confirm.
    Queued {
        position: usize,
    },
    /// The request has started.
    Started,
    /// The request failed.
    Failed {
        code: String,
        message: String,
        request_id: Uuid,
        /// Retry duration in seconds.
        retry_after: Option<f64>,
    },
    /// The usage `amount` and its `limit` were updated.
    UsageUpdated {
        amount: usize,
        limit: UsageLimit,
    },
    /// A tool-use limit was reached.
    // NOTE(review): presumably the status-message counterpart of
    // `TOOL_USE_LIMIT_REACHED_HEADER_NAME` — confirm.
    ToolUseLimitReached,
}

/// An event associated with a completion request: either a [`CompletionRequestStatus`]
/// update or a payload of type `T`.
///
/// Serialized with `snake_case` variant names (`status`, `event`).
// `Debug` is derived so events can be logged; it only requires `T: Debug` where used.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionEvent<T> {
    /// A status update about the request.
    Status(CompletionRequestStatus),
    /// A payload event.
    // NOTE(review): `T` is presumably the provider's own event type — confirm.
    Event(T),
}

impl<T> CompletionEvent<T> {
    /// Consumes the event, yielding the status when it is [`CompletionEvent::Status`]
    /// and `None` for a payload event.
    pub fn into_status(self) -> Option<CompletionRequestStatus> {
        if let Self::Status(status) = self {
            Some(status)
        } else {
            None
        }
    }

    /// Consumes the event, yielding the payload when it is [`CompletionEvent::Event`]
    /// and `None` for a status update.
    pub fn into_event(self) -> Option<T> {
        if let Self::Event(payload) = self {
            Some(payload)
        } else {
            None
        }
    }
}

/// The request body for a web search.
// `Debug` is derived (as on the other public wire types) so requests can be logged.
#[derive(Debug, Serialize, Deserialize)]
pub struct WebSearchBody {
    /// The search query.
    pub query: String,
}

/// The results of a web search.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResponse {
    pub results: Vec<WebSearchResult>,
}

/// A single web search result.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WebSearchResult {
    pub title: String,
    pub url: String,
    pub text: String,
}

/// The request body for counting the tokens in a provider request.
// `Debug` is derived (as on the other public wire types) so requests can be logged.
#[derive(Debug, Serialize, Deserialize)]
pub struct CountTokensBody {
    /// The provider the request is intended for.
    pub provider: LanguageModelProvider,
    /// The model identifier.
    pub model: String,
    /// The request whose tokens should be counted, kept as untyped JSON.
    pub provider_request: serde_json::Value,
}

/// The response to a token-counting request.
#[derive(Debug, Serialize, Deserialize)]
pub struct CountTokensResponse {
    /// The number of tokens counted.
    pub tokens: usize,
}

/// The identifier of a language model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelId(pub Arc<str>);

impl std::fmt::Display for LanguageModelId {
    /// Writes the raw identifier.
    ///
    /// Uses [`std::fmt::Formatter::pad`] so that width, alignment, and precision flags
    /// (e.g. `{:>20}`) are honored exactly as they are for a plain `str`; the previous
    /// `write!(f, "{}", ..)` built fresh format arguments and silently dropped them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.0)
    }
}

/// Metadata describing a language model, as listed in [`ListModelsResponse::models`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LanguageModel {
    /// The provider of this model.
    pub provider: LanguageModelProvider,
    /// The model's identifier.
    pub id: LanguageModelId,
    /// A human-readable name for the model.
    pub display_name: String,
    // NOTE(review): presumably the context-window size in tokens — confirm.
    pub max_token_count: usize,
    // NOTE(review): presumably the token limit under `CompletionMode::Max`, with `None`
    // meaning no separate limit — confirm.
    pub max_token_count_in_max_mode: Option<usize>,
    pub max_output_tokens: usize,
    pub supports_tools: bool,
    pub supports_images: bool,
    pub supports_thinking: bool,
    pub supports_max_mode: bool,
}

/// The list of available models together with default and recommended model IDs.
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
    pub models: Vec<LanguageModel>,
    pub default_model: LanguageModelId,
    // NOTE(review): presumably the default for latency-sensitive/cheap tasks — confirm.
    pub default_fast_model: LanguageModelId,
    pub recommended_models: Vec<LanguageModelId>,
}

/// The user's current plan and, when present, their current usage.
#[derive(Debug, Serialize, Deserialize)]
pub struct GetSubscriptionResponse {
    pub plan: Plan,
    // NOTE(review): the circumstances under which this is `None` are not defined here — confirm.
    pub usage: Option<CurrentUsage>,
}

/// Usage for each metered resource.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct CurrentUsage {
    pub model_requests: UsageData,
    pub edit_predictions: UsageData,
}

/// How much of a resource has been used, and the limit that applies to it.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct UsageData {
    pub used: u32,
    pub limit: UsageLimit,
}

#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    #[test]
    fn test_plan_deserialize_snake_case() {
        let plan = serde_json::from_value::<Plan>(json!("zed_free")).unwrap();
        assert_eq!(plan, Plan::ZedFree);

        let plan = serde_json::from_value::<Plan>(json!("zed_pro")).unwrap();
        assert_eq!(plan, Plan::ZedPro);

        let plan = serde_json::from_value::<Plan>(json!("zed_pro_trial")).unwrap();
        assert_eq!(plan, Plan::ZedProTrial);
    }

    #[test]
    fn test_plan_deserialize_aliases() {
        let plan = serde_json::from_value::<Plan>(json!("Free")).unwrap();
        assert_eq!(plan, Plan::ZedFree);

        let plan = serde_json::from_value::<Plan>(json!("ZedPro")).unwrap();
        assert_eq!(plan, Plan::ZedPro);

        let plan = serde_json::from_value::<Plan>(json!("ZedProTrial")).unwrap();
        assert_eq!(plan, Plan::ZedProTrial);
    }

    /// `Plan::as_str`, `Plan::from_str`, and serde must all agree on the spelling of each
    /// plan, otherwise headers and JSON bodies would disagree.
    #[test]
    fn test_plan_as_str_round_trips_through_from_str_and_serde() {
        for plan in [Plan::ZedFree, Plan::ZedPro, Plan::ZedProTrial] {
            assert_eq!(Plan::from_str(plan.as_str()).unwrap(), plan);
            assert_eq!(serde_json::to_value(plan).unwrap(), json!(plan.as_str()));
        }
    }

    /// `FromStr` only accepts the `snake_case` identifiers, not the serde aliases.
    #[test]
    fn test_plan_from_str_rejects_unknown_values() {
        for value in ["", "pro", "Free", "ZedPro"] {
            assert!(
                Plan::from_str(value).is_err(),
                "expected {value:?} to be rejected"
            );
        }
    }

    #[test]
    fn test_usage_limit_from_str() {
        // `assert_eq!` (rather than `assert!(matches!(..))`) reports the actual value on failure.
        assert_eq!(
            UsageLimit::from_str("unlimited").unwrap(),
            UsageLimit::Unlimited
        );
        assert_eq!(UsageLimit::from_str("0").unwrap(), UsageLimit::Limited(0));
        assert_eq!(UsageLimit::from_str("50").unwrap(), UsageLimit::Limited(50));

        for value in ["", "not_a_number", "50xyz", "Unlimited"] {
            assert!(
                UsageLimit::from_str(value).is_err(),
                "expected {value:?} to be rejected"
            );
        }
    }

    /// Pins the JSON wire format of `UsageLimit`.
    #[test]
    fn test_usage_limit_serde() {
        assert_eq!(
            serde_json::to_value(UsageLimit::Unlimited).unwrap(),
            json!("unlimited")
        );
        assert_eq!(
            serde_json::to_value(UsageLimit::Limited(50)).unwrap(),
            json!({ "limited": 50 })
        );
        assert_eq!(
            serde_json::from_value::<UsageLimit>(json!({ "limited": 7 })).unwrap(),
            UsageLimit::Limited(7)
        );
    }
}
