// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
//
// SPDX-License-Identifier: AGPL-3.0-or-later

package auth_test

import (
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"

	"git.secluded.site/lune/internal/mcp/auth"
)

// TestHashAndVerify checks that a freshly hashed token verifies against its
// own hash, and that a different token is rejected by that same hash.
func TestHashAndVerify(t *testing.T) {
	t.Parallel()

	const plain = "my-secret-token"

	digest, err := auth.Hash(plain)

	switch {
	case err != nil:
		t.Fatalf("Hash() error = %v", err)
	case digest == "":
		t.Fatal("Hash() returned empty string")
	}

	// The original token must round-trip.
	if ok := auth.Verify(plain, digest); !ok {
		t.Error("Verify() returned false for correct token")
	}

	// Any other token must be refused.
	if ok := auth.Verify("wrong-token", digest); ok {
		t.Error("Verify() returned true for wrong token")
	}
}

// TestHashUniqueSalts hashes the same token twice and expects two different
// encodings (evidence that each call salts independently), while both
// encodings must still verify against the original token.
func TestHashUniqueSalts(t *testing.T) {
	t.Parallel()

	const plain = "same-token"

	var digests [2]string

	for i := range digests {
		d, err := auth.Hash(plain)
		if err != nil {
			t.Fatalf("Hash() error = %v", err)
		}

		digests[i] = d
	}

	if digests[0] == digests[1] {
		t.Error("Hash() produced identical hashes for same token (salt not random)")
	}

	// Distinct salts must not prevent verification of either encoding.
	for i, d := range digests {
		if !auth.Verify(plain, d) {
			t.Errorf("Verify() failed for hash%d", i+1)
		}
	}
}

// TestVerifyInvalidHash feeds Verify a series of malformed or foreign hash
// encodings and requires every one of them to be rejected.
//
// The argon2id-shaped entries appear to follow a `$alg$v=..$params$salt$hash`
// layout (inferred from the case names; confirm against auth.Hash's output).
// Each entry breaks exactly one part of that layout.
func TestVerifyInvalidHash(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name    string
		encoded string
	}{
		{name: "empty", encoded: ""},
		{name: "garbage", encoded: "not-a-valid-hash"},
		{name: "wrong algorithm", encoded: "$bcrypt$v=2$something$salt$hash"},
		{name: "too few parts", encoded: "$argon2id$v=19$m=19456$salt"},
		{name: "invalid base64 salt", encoded: "$argon2id$v=19$m=19456,t=2,p=1$!!!invalid$hash"},
		{name: "invalid base64 hash", encoded: "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$!!!invalid"},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()

			if ok := auth.Verify("token", c.encoded); ok {
				t.Errorf("Verify() returned true for invalid hash %q", c.encoded)
			}
		})
	}
}

// TestMiddleware drives the Authorization-header path of auth.Middleware
// through every scenario from middlewareTestCases. The wrapped handler simply
// answers 200 with body "OK".
func TestMiddleware(t *testing.T) {
	t.Parallel()

	const plain = "test-bearer-token"

	digest, err := auth.Hash(plain)
	if err != nil {
		t.Fatalf("Hash() error = %v", err)
	}

	okHandler := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		// A write failure on the recorder is irrelevant to what is under test.
		_, _ = w.Write([]byte("OK"))
	}

	guarded := auth.Middleware(digest)(http.HandlerFunc(okHandler))

	for _, tc := range middlewareTestCases(plain) {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			runMiddlewareTest(t, guarded, tc)
		})
	}
}

// middlewareTestCase describes one request that runMiddlewareTest sends through
// the protected handler, plus the response it is expected to produce:
//
//   - name is the subtest name passed to t.Run.
//   - authHeader is the raw Authorization header value. An empty string means
//     runMiddlewareTest leaves the header unset.
//   - wantStatus is the expected HTTP status code.
//   - wantAuthHeader reports whether the response must carry a
//     WWW-Authenticate challenge. When true, runMiddlewareTest also checks the
//     challenge's exact value.
type middlewareTestCase struct {
	name           string
	authHeader     string
	wantStatus     int
	wantAuthHeader bool
}

// middlewareTestCases returns the Authorization-header scenarios exercised by
// TestMiddleware. Every rejected request is expected to yield 401 plus a
// WWW-Authenticate challenge. Only "Bearer <token>" is expected to succeed,
// with 200 and no challenge.
func middlewareTestCases(token string) []middlewareTestCase {
	rejected := func(name, header string) middlewareTestCase {
		return middlewareTestCase{
			name:           name,
			authHeader:     header,
			wantStatus:     http.StatusUnauthorized,
			wantAuthHeader: true,
		}
	}

	return []middlewareTestCase{
		rejected("missing header", ""),
		rejected("wrong scheme", "Basic dXNlcjpwYXNz"),
		rejected("wrong token", "Bearer wrong-token"),
		{
			name:       "valid token",
			authHeader: "Bearer " + token,
			wantStatus: http.StatusOK,
			// wantAuthHeader stays false: a successful request carries no challenge.
		},
	}
}

// runMiddlewareTest sends a GET / through protected and checks the response
// against testCase. The Authorization header is set only when
// testCase.authHeader is non-empty. The checks cover:
//
//   - the status code;
//   - whether a WWW-Authenticate header is present;
//   - when a challenge is expected, that it is exactly the lune-mcp Bearer
//     realm.
func runMiddlewareTest(t *testing.T, protected http.Handler, testCase middlewareTestCase) {
	t.Helper()

	const wantChallenge = `Bearer realm="lune-mcp"`

	request := httptest.NewRequest(http.MethodGet, "/", nil)
	if header := testCase.authHeader; header != "" {
		request.Header.Set("Authorization", header)
	}

	resp := httptest.NewRecorder()
	protected.ServeHTTP(resp, request)

	if got, want := resp.Code, testCase.wantStatus; got != want {
		t.Errorf("status = %d, want %d", got, want)
	}

	challenge := resp.Header().Get("WWW-Authenticate")

	if present := challenge != ""; present != testCase.wantAuthHeader {
		t.Errorf("WWW-Authenticate header present = %v, want %v", present, testCase.wantAuthHeader)
	}

	if testCase.wantAuthHeader && challenge != wantChallenge {
		t.Errorf("WWW-Authenticate = %q, want %q", challenge, wantChallenge)
	}
}

// TestMiddlewareQueryParam covers authentication through the access_token
// query parameter. It also covers a request that carries both a query
// parameter and an Authorization header.
func TestMiddlewareQueryParam(t *testing.T) {
	t.Parallel()

	token := "query-param-token" //nolint:gosec // test fixture, not a real credential

	hash, err := auth.Hash(token)
	if err != nil {
		t.Fatalf("Hash() error = %v", err)
	}

	handler := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.WriteHeader(http.StatusOK)
	})

	protected := auth.Middleware(hash)(handler)

	// withAccessToken builds a request target carrying tok as the access_token
	// query parameter. url.Values escapes the value, so a fixture token with
	// reserved characters ('+', '&', '%', '=') reaches the server intact.
	// Plain string concatenation would let query parsing alter it.
	withAccessToken := func(tok string) string {
		return "/?" + url.Values{"access_token": {tok}}.Encode()
	}

	t.Run("valid query param", func(t *testing.T) {
		t.Parallel()

		req := httptest.NewRequest(http.MethodGet, withAccessToken(token), nil)
		recorder := httptest.NewRecorder()
		protected.ServeHTTP(recorder, req)

		if recorder.Code != http.StatusOK {
			t.Errorf("status = %d, want %d", recorder.Code, http.StatusOK)
		}
	})

	t.Run("invalid query param", func(t *testing.T) {
		t.Parallel()

		req := httptest.NewRequest(http.MethodGet, withAccessToken("wrong-token"), nil)
		recorder := httptest.NewRecorder()
		protected.ServeHTTP(recorder, req)

		if recorder.Code != http.StatusUnauthorized {
			t.Errorf("status = %d, want %d", recorder.Code, http.StatusUnauthorized)
		}
	})

	t.Run("header takes precedence over query param", func(t *testing.T) {
		t.Parallel()

		// NOTE(review): this only shows that a valid header is accepted despite
		// an invalid query parameter. It would also pass if the middleware
		// accepted either credential. Pinning real precedence needs the inverse
		// case (invalid header + valid query param). Add it once the intended
		// behaviour for that combination is confirmed.
		req := httptest.NewRequest(http.MethodGet, withAccessToken("wrong-token"), nil)
		req.Header.Set("Authorization", "Bearer "+token)

		recorder := httptest.NewRecorder()
		protected.ServeHTTP(recorder, req)

		if recorder.Code != http.StatusOK {
			t.Errorf("status = %d, want %d (header should take precedence)", recorder.Code, http.StatusOK)
		}
	})
}
