// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
//
// SPDX-License-Identifier: AGPL-3.0-or-later

package event

import (
	"context"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"git.secluded.site/np/internal/db"
	"git.secluded.site/np/internal/timeutil"
)

// Sentinel errors returned when appending events; compare them with errors.Is.
var (
	// ErrInvalidType is returned when an event's Type does not satisfy
	// Type.Valid.
	ErrInvalidType = errors.New("event: invalid type")

	// ErrEmptyCommand is returned when an event's command is empty or
	// consists only of whitespace.
	ErrEmptyCommand = errors.New("event: command is required")
)

// Type enumerates known event kinds. Its string value is what gets persisted
// in each stored Record's "type" JSON field.
type Type string

// Known event types; these are exactly the values accepted by Type.Valid.
// The strings are written to storage, so changing one would leave previously
// stored events carrying the old value.
const (
	TypeGoalSet          Type = "goal_set"
	TypeGoalUpdated      Type = "goal_updated"
	TypeTaskAdded        Type = "task_added"
	TypeTaskUpdated      Type = "task_updated"
	TypeTaskStatusChange Type = "task_status_changed"
)

// Valid reports whether t is one of the package's known event types. The
// comparison is exact: no case folding or trimming is applied.
func (t Type) Valid() bool {
	switch t {
	case TypeGoalSet, TypeGoalUpdated:
		return true
	case TypeTaskAdded, TypeTaskUpdated, TypeTaskStatusChange:
		return true
	}
	return false
}

// Record represents a stored event entry. appendRecord writes it with
// txn.SetJSON under the event key for (sid, Seq); listRecords decodes it back
// with item.ValueJSON.
type Record struct {
	// Seq is the event's sequence number within its session, obtained by
	// incrementing the session's sequence counter key.
	Seq     uint64          `json:"seq"`
	// At is the event timestamp (AppendInput.At, or the clock's current time
	// when that was zero), normalised with timeutil.EnsureUTC.
	At      time.Time       `json:"at"`
	// Type is the event kind; records created by appendRecord always satisfy
	// Type.Valid.
	Type    Type            `json:"type"`
	// Reason is the whitespace-trimmed reason, or nil when none was given
	// (in which case it is omitted from the JSON).
	Reason  *string         `json:"reason,omitempty"`
	// Command is the whitespace-trimmed command string; appendRecord rejects
	// empty commands.
	Command string          `json:"cmd"`
	// Payload is the JSON encoding of AppendInput.Payload, or the literal
	// null when no payload was supplied.
	Payload json.RawMessage `json:"payload"`
}

// HasReason reports whether the record carries a non-empty reason. A nil
// Reason and a pointer to an empty string both count as "no reason".
func (r Record) HasReason() bool {
	if r.Reason == nil {
		return false
	}
	return len(*r.Reason) > 0
}

// UnmarshalPayload decodes the record's JSON payload into dst using
// json.Unmarshal. It fails when the payload is entirely absent (zero length).
// A payload stored from a nil AppendInput.Payload is the JSON literal null,
// which json.Unmarshal accepts without error.
func (r Record) UnmarshalPayload(dst any) error {
	if len(r.Payload) > 0 {
		return json.Unmarshal(r.Payload, dst)
	}
	return errors.New("event: payload is empty")
}

// AppendInput captures the data necessary to append an event. appendRecord
// validates and normalises it before anything is written.
type AppendInput struct {
	// Type must satisfy Type.Valid, otherwise ErrInvalidType is returned.
	Type    Type
	// Command is trimmed of surrounding whitespace and must be non-empty
	// afterwards, otherwise ErrEmptyCommand is returned.
	Command string
	// Reason is optional; it is trimmed and dropped entirely when blank.
	Reason  string
	// Payload is encoded with encoding/json; nil is stored as JSON null.
	Payload any
	// At is the event time. The zero value means "now" according to the
	// store's clock; either way the value is passed through
	// timeutil.EnsureUTC.
	At      time.Time
}

// ListOptions controls event listing. The zero value lists every event whose
// sequence is greater than zero, without a limit.
type ListOptions struct {
	// After skips events with seq <= After, so passing the last Seq already
	// seen resumes listing immediately after it.
	After uint64
	// Limit restricts the number of events returned. Zero or negative returns all.
	// It counts only events that survive the After filter.
	Limit int
}

// Store provides high-level helpers for working with events. Each method
// opens its own transaction on db (Update for writes, View for reads); use
// WithTxn to run the same operations inside a transaction you already hold.
// Construct it with NewStore.
type Store struct {
	db    *db.Database
	// clock supplies timestamps for events appended without an explicit At.
	clock timeutil.Clock
}

// NewStore returns a Store backed by database. If clock is nil, the store
// falls back to a UTC system clock (timeutil.UTCClock).
func NewStore(database *db.Database, clock timeutil.Clock) *Store {
	store := &Store{db: database, clock: clock}
	if store.clock == nil {
		store.clock = timeutil.UTCClock{}
	}
	return store
}

// WithTxn returns a TxnStore that runs event operations inside txn, sharing
// this store's clock. It is intended for callers already inside a
// transaction callback, such as those passed to the database's Update or
// View methods.
func (s *Store) WithTxn(txn *db.Txn) TxnStore {
	var bound TxnStore
	bound.txn = txn
	bound.clock = s.clock
	return bound
}

// Append records an event for sid in its own read-write transaction
// (s.db.Update) and returns the record as written.
//
// On any error the returned Record is the zero value. In particular, if
// Update reports an error after the callback succeeded (presumably a failed
// commit; confirm against db.Update), the sequence number allocated inside
// the transaction may never have been persisted, so the record must not be
// handed back to the caller.
func (s *Store) Append(ctx context.Context, sid string, input AppendInput) (Record, error) {
	var rec Record
	err := s.db.Update(ctx, func(txn *db.Txn) error {
		var err error
		rec, err = appendRecord(txn, s.clock, sid, input)
		return err
	})
	if err != nil {
		return Record{}, err
	}
	return rec, nil
}

// List returns the events stored for sid, filtered and limited by opts, using
// a read-only transaction (s.db.View).
func (s *Store) List(ctx context.Context, sid string, opts ListOptions) ([]Record, error) {
	var out []Record
	view := func(txn *db.Txn) (viewErr error) {
		out, viewErr = listRecords(txn, sid, opts)
		return viewErr
	}
	err := s.db.View(ctx, view)
	return out, err
}

// LatestSequence reports the most recently allocated event sequence for sid,
// or 0 when no event has been appended yet, using a read-only transaction.
func (s *Store) LatestSequence(ctx context.Context, sid string) (uint64, error) {
	var latest uint64
	err := s.db.View(ctx, func(txn *db.Txn) error {
		seq, seqErr := latestSequence(txn, sid)
		latest = seq
		return seqErr
	})
	return latest, err
}

// TxnStore performs event operations inside a caller-owned transaction.
// Obtain one from Store.WithTxn. None of its methods commit or discard the
// transaction; that remains the caller's responsibility.
type TxnStore struct {
	txn   *db.Txn
	clock timeutil.Clock
}

// Append stores a new event for sid within the wrapped transaction and
// returns the record as written.
func (ts TxnStore) Append(sid string, input AppendInput) (Record, error) {
	return appendRecord(ts.txn, ts.clock, sid, input)
}

// List reads the events for sid from the wrapped transaction, applying opts.
func (ts TxnStore) List(sid string, opts ListOptions) ([]Record, error) {
	return listRecords(ts.txn, sid, opts)
}

// LatestSequence reads sid's sequence counter from the wrapped transaction,
// returning 0 when no event has been appended yet.
func (ts TxnStore) LatestSequence(sid string) (uint64, error) {
	return latestSequence(ts.txn, sid)
}

// Prefix returns the subscription prefix for events in sid. listRecords reads
// events back by iterating over this prefix.
func Prefix(sid string) []byte {
	return db.PrefixSessionEvents(sid)
}

// Key returns the event key for seq within sid. appendRecord stores each
// Record as JSON under this key.
func Key(sid string, seq uint64) []byte {
	return db.KeySessionEvent(sid, seq)
}

// SequenceCounterKey returns the key that stores the latest sequence for sid.
// appendRecord increments the value at this key by one per event, and
// latestSequence decodes it as an 8-byte big-endian uint64, reporting 0 when
// the key does not exist.
func SequenceCounterKey(sid string) []byte {
	return db.KeySessionEventSeq(sid)
}

// nullPayload is the JSON literal stored as Record.Payload when no payload is
// supplied. It is shared package state and must never be mutated.
var nullPayload = json.RawMessage("null")

// appendRecord validates input, assigns the next sequence number for sid and
// writes the resulting Record as JSON inside txn, returning that record.
//
// A nil txn, an invalid Type or a blank Command are rejected before anything
// is read or written. The payload is encoded before the sequence counter is
// incremented, so an unencodable payload consumes no sequence number. A nil
// clock falls back to timeutil.UTCClock.
func appendRecord(txn *db.Txn, clock timeutil.Clock, sid string, input AppendInput) (Record, error) {
	switch {
	case txn == nil:
		return Record{}, errors.New("event: transaction is nil")
	case !input.Type.Valid():
		return Record{}, ErrInvalidType
	}

	command := strings.TrimSpace(input.Command)
	if len(command) == 0 {
		return Record{}, ErrEmptyCommand
	}

	if clock == nil {
		clock = timeutil.UTCClock{}
	}
	when := input.At
	if when.IsZero() {
		when = clock.Now()
	}
	when = timeutil.EnsureUTC(when)

	rawPayload, err := marshalPayload(input.Payload)
	if err != nil {
		return Record{}, err
	}

	nextSeq, err := txn.IncrementUint64(db.KeySessionEventSeq(sid), 1)
	if err != nil {
		return Record{}, err
	}

	rec := Record{
		Seq:     nextSeq,
		At:      when,
		Type:    input.Type,
		Command: command,
		Payload: rawPayload,
	}
	// A blank reason is stored as nil so that it is omitted from the JSON.
	if reason := strings.TrimSpace(input.Reason); reason != "" {
		rec.Reason = &reason
	}

	if err = txn.SetJSON(db.KeySessionEvent(sid, nextSeq), rec); err != nil {
		return Record{}, err
	}
	return rec, nil
}

// listRecords returns the events stored for sid in the order txn.Iterate
// yields them. It skips every record with Seq <= opts.After and stops once
// opts.Limit records have been collected (when Limit > 0).
//
// Every event under the prefix is decoded, including those later skipped by
// opts.After, because the sequence number is read from the decoded record.
func listRecords(txn *db.Txn, sid string, opts ListOptions) ([]Record, error) {
	if txn == nil {
		return nil, errors.New("event: transaction is nil")
	}

	// errLimitReached ends iteration early once Limit records are collected.
	// It is private to this call, so a genuine db.ErrTxnAborted reported by
	// the database is surfaced to the caller. Otherwise that error would be
	// mistaken for the early-exit signal and silently truncate the result.
	errLimitReached := errors.New("event: list limit reached")

	var records []Record
	err := txn.Iterate(db.IterateOptions{
		Prefix:         db.PrefixSessionEvents(sid),
		PrefetchValues: true,
	}, func(item db.Item) error {
		var rec Record
		if err := item.ValueJSON(&rec); err != nil {
			return err
		}

		if rec.Seq <= opts.After {
			return nil
		}

		records = append(records, rec)
		if opts.Limit > 0 && len(records) >= opts.Limit {
			return errLimitReached
		}

		return nil
	})
	if err != nil && !errors.Is(err, errLimitReached) {
		return nil, err
	}

	return records, nil
}

// latestSequence reads sid's event sequence counter from txn. It returns 0
// when the counter key does not exist yet. It returns an error when the
// stored value is not exactly 8 bytes, which it decodes as a big-endian
// uint64.
func latestSequence(txn *db.Txn, sid string) (uint64, error) {
	if txn == nil {
		return 0, errors.New("event: transaction is nil")
	}

	counterKey := db.KeySessionEventSeq(sid)
	found, err := txn.Exists(counterKey)
	switch {
	case err != nil:
		return 0, err
	case !found:
		return 0, nil
	}

	raw, err := txn.Get(counterKey)
	if err != nil {
		return 0, err
	}
	if len(raw) == 8 {
		return binary.BigEndian.Uint64(raw), nil
	}
	return 0, fmt.Errorf("event: corrupt sequence counter for session %q", sid)
}

// marshalPayload encodes payload as JSON for storage in Record.Payload. A nil
// payload is encoded as the JSON literal null.
//
// The null case returns a fresh copy of nullPayload instead of the shared
// package-level slice. Records are returned to callers, so a caller mutating
// Record.Payload would otherwise corrupt every later null payload. Marshal
// errors are wrapped with %w, so errors.Is/errors.As still reach the
// underlying encoding/json error.
func marshalPayload(payload any) (json.RawMessage, error) {
	if payload == nil {
		return append(json.RawMessage(nil), nullPayload...), nil
	}
	data, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("event: marshal payload: %w", err)
	}
	if len(data) == 0 {
		// Defensive: keep the stored payload valid JSON even if the encoder
		// produced no output.
		return append(json.RawMessage(nil), nullPayload...), nil
	}
	return json.RawMessage(data), nil
}
