//usr/bin/env go run "$0" "$@"; exit
package main

import (
	"bytes"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"sync"
	"time"
)

// syntheticAPI is the Synthetic count_tokens endpoint (Anthropic-style API)
// that every token-count request is POSTed to.
const syntheticAPI = "https://api.synthetic.new/anthropic/v1/messages/count_tokens"

// model is the model named in each count_tokens request.
const model = "hf:moonshotai/Kimi-K2.5"

// workerCount is the default value of the -workers flag.
const workerCount = 5

// httpClient is the single client shared by all token-count requests; its
// 30-second timeout bounds each request as a whole.
var httpClient = &http.Client{Timeout: 30 * time.Second}

// Frontmatter holds the two SKILL.md frontmatter keys this tool reads:
// the skill's name and its description.
type Frontmatter struct {
	Name, Description string
}

// TokenCount records the token count of each part of a skill. Total is the
// sum of Name, Description, Body and every entry in References.
type TokenCount struct {
	Name, Description, Body int

	// References maps a file name inside the skill's references/ directory
	// to that file's token count.
	References map[string]int

	Total int
}

// SkillInfo collects what analyzeSkill learns about one skill directory.
// Dir is the directory's base name. Frontmatter and BodyLines come from
// SKILL.md; BodyLines is the line count of the whitespace-trimmed body, or 0
// when the body is empty. Tokens holds the counts returned by the
// token-count API. Errors lists validation failures and any problems reading
// files or counting tokens.
type SkillInfo struct {
	Dir         string
	Frontmatter Frontmatter
	BodyLines   int
	Tokens      TokenCount
	Errors      []string
}

// TokenJob is one unit of work for a TokenCounter: Text is sent for token
// counting and the outcome is reported back under the same ID.
type TokenJob struct {
	ID, Text string
}

// TokenResult is what a TokenCounter worker reports for one TokenJob: the
// job's ID plus either the token count or the error returned while counting.
type TokenResult struct {
	ID    string
	Count int
	Err   error
}

// SkillComparison describes how a skill's token counts changed relative to
// the version committed at HEAD.
//
// The Prev* fields hold HEAD's total, metadata (name + description) and body
// counts. The *Delta fields are the current value minus the HEAD value.
// Percent is Delta as a percentage of PrevTotal; it is 0 when PrevTotal is 0
// and 100 for a new skill. IsNew marks a skill whose HEAD version could not
// be loaded; its Prev* fields are all 0.
type SkillComparison struct {
	PrevTotal     int
	PrevMetadata  int
	PrevBody      int
	Delta         int
	MetadataDelta int
	BodyDelta     int
	Percent       float64
	IsNew         bool
}

// main analyzes every skill under ./skills and counts tokens through the
// Synthetic count_tokens API; SYNTHETIC_API_KEY must be set. It prints one
// report per skill in directory-name order, followed by a summary. With
// -compare it also reports each skill's change relative to the HEAD commit.
func main() {
	compare := flag.Bool("compare", false, "Compare with HEAD commit")
	workers := flag.Int("workers", workerCount, "Number of parallel API workers")
	flag.Parse()

	// A pool without workers never produces a result, so the first wait for
	// a token count would deadlock; reject such values up front.
	if *workers < 1 {
		fmt.Fprintf(os.Stderr, "Error: -workers must be at least 1, got %d\n", *workers)
		os.Exit(1)
	}

	apiKey := os.Getenv("SYNTHETIC_API_KEY")
	if apiKey == "" {
		fmt.Fprintln(os.Stderr, "Error: SYNTHETIC_API_KEY environment variable not set")
		os.Exit(1)
	}

	// Start worker pool
	counter := newTokenCounter(apiKey, *workers)
	defer counter.Close()

	skills, err := analyzeSkills(counter)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error: %v\n", err)
		os.Exit(1)
	}

	// Build comparison map if requested
	var comparisons map[string]SkillComparison
	if *compare {
		comparisons = buildComparisons(skills, counter)
	}

	// Sort skills by name for consistent output
	sort.Slice(skills, func(i, j int) bool {
		return skills[i].Dir < skills[j].Dir
	})

	// Print reports
	for _, skill := range skills {
		var comp *SkillComparison
		if c, ok := comparisons[skill.Dir]; ok {
			comp = &c
		}
		printSkillReport(skill, comp)
	}

	// Print summary
	printSummary(skills, comparisons)
}

// TokenCounter runs a fixed pool of goroutines that count tokens with
// countTokensAPI. Callers queue work with Count and receive exactly one
// TokenResult per queued job from GetResult or TryGetResult. Results arrive
// in completion order, not submission order, so callers match them by
// TokenJob.ID.
type TokenCounter struct {
	apiKey  string
	jobs    chan TokenJob
	results chan TokenResult
	wg      sync.WaitGroup
}

// newTokenCounter starts a pool of workers goroutines that authenticate with
// apiKey. A workers value below 1 is raised to 1. With no workers nothing
// would ever be counted, and the first GetResult would block forever.
func newTokenCounter(apiKey string, workers int) *TokenCounter {
	if workers < 1 {
		workers = 1
	}

	// Both queues hold up to 100 items, so submitters and workers can run
	// ahead of each other.
	tc := &TokenCounter{
		apiKey:  apiKey,
		jobs:    make(chan TokenJob, 100),
		results: make(chan TokenResult, 100),
	}

	tc.wg.Add(workers)
	for i := 0; i < workers; i++ {
		go tc.worker()
	}

	return tc
}

// worker counts tokens for queued jobs until the jobs channel is closed.
func (tc *TokenCounter) worker() {
	defer tc.wg.Done()
	for job := range tc.jobs {
		count, err := countTokensAPI(tc.apiKey, job.Text)
		tc.results <- TokenResult{ID: job.ID, Count: count, Err: err}
	}
}

// Count queues text for counting; its result will carry id. It blocks while
// the job queue is full.
func (tc *TokenCounter) Count(id, text string) {
	tc.jobs <- TokenJob{ID: id, Text: text}
}

// GetResult blocks until a result is available and returns it. Once Close
// has run and any buffered results are consumed, it returns the zero
// TokenResult.
func (tc *TokenCounter) GetResult() TokenResult {
	return <-tc.results
}

// TryGetResult returns a ready result and true. When no result is available
// right now, it returns the zero TokenResult and false without blocking.
func (tc *TokenCounter) TryGetResult() (TokenResult, bool) {
	select {
	case r := <-tc.results:
		return r, true
	default:
		return TokenResult{}, false
	}
}

// Close stops accepting jobs, waits for every worker to exit, and then
// closes the results channel. Workers block while the results buffer is
// full, so results beyond its capacity must be consumed before Close can
// return. Close must be called at most once, because closing a channel
// twice panics.
func (tc *TokenCounter) Close() {
	close(tc.jobs)
	tc.wg.Wait()
	close(tc.results)
}

// analyzeSkills runs analyzeSkill on every sub-directory of ./skills, in the
// order os.ReadDir returns them (sorted by file name). Non-directory entries
// are ignored. A skill whose analysis returns an error is skipped with a
// warning on stderr. Only a failure to read ./skills itself is returned as an
// error.
func analyzeSkills(counter *TokenCounter) ([]SkillInfo, error) {
	const root = "skills"

	dirEntries, err := os.ReadDir(root)
	if err != nil {
		return nil, fmt.Errorf("cannot read skills directory: %w", err)
	}

	var found []SkillInfo
	for _, de := range dirEntries {
		if de.IsDir() {
			info, aerr := analyzeSkill(filepath.Join(root, de.Name()), counter)
			if aerr != nil {
				fmt.Fprintf(os.Stderr, "Warning: error analyzing %s: %v\n", de.Name(), aerr)
			} else {
				found = append(found, info)
			}
		}
	}
	return found, nil
}

func analyzeSkill(path string, counter *TokenCounter) (SkillInfo, error) {
	skill := SkillInfo{
		Dir: filepath.Base(path),
		Tokens: TokenCount{
			References: make(map[string]int),
		},
	}

	// Read SKILL.md
	skillMdPath := filepath.Join(path, "SKILL.md")
	content, err := os.ReadFile(skillMdPath)
	if err != nil {
		skill.Errors = append(skill.Errors, fmt.Sprintf("Cannot read SKILL.md: %v", err))
		return skill, nil
	}

	// Parse frontmatter and body
	fm, body, err := parseFrontmatter(string(content))
	if err != nil {
		skill.Errors = append(skill.Errors, fmt.Sprintf("Cannot parse frontmatter: %v", err))
		return skill, nil
	}
	skill.Frontmatter = fm
	trimmedBody := strings.TrimSpace(body)
	if trimmedBody == "" {
		skill.BodyLines = 0
	} else {
		skill.BodyLines = len(strings.Split(trimmedBody, "\n"))
	}

	// Validate
	skill.Errors = append(skill.Errors, validateSkill(skill)...)

	fmt.Fprintf(os.Stderr, "Analyzing %s...\n", skill.Dir)

	// Collect all jobs first (no channel use yet)
	var jobs []TokenJob
	jobs = append(jobs, TokenJob{ID: fmt.Sprintf("%s:name", skill.Dir), Text: fm.Name})
	jobs = append(jobs, TokenJob{ID: fmt.Sprintf("%s:description", skill.Dir), Text: fm.Description})
	jobs = append(jobs, TokenJob{ID: fmt.Sprintf("%s:body", skill.Dir), Text: body})

	// Collect reference file jobs
	refsPath := filepath.Join(path, "references")
	entries, err := os.ReadDir(refsPath)
	if err != nil {
		if !os.IsNotExist(err) {
			skill.Errors = append(skill.Errors, fmt.Sprintf("Cannot read references directory: %v", err))
		}
	} else {
		for _, entry := range entries {
			if entry.IsDir() {
				continue
			}
			refPath := filepath.Join(refsPath, entry.Name())
			refContent, err := os.ReadFile(refPath)
			if err != nil {
				skill.Errors = append(skill.Errors, fmt.Sprintf("Cannot read reference %s: %v", entry.Name(), err))
				continue
			}
			jobs = append(jobs, TokenJob{
				ID:   fmt.Sprintf("%s:ref:%s", skill.Dir, entry.Name()),
				Text: string(refContent),
			})
		}
	}

	// Interleave enqueue and drain to prevent deadlock
	processResult := func(result TokenResult) {
		if result.Err != nil {
			skill.Errors = append(skill.Errors, fmt.Sprintf("Token count failed for %s: %v", result.ID, result.Err))
			return
		}
		parts := strings.SplitN(result.ID, ":", 3)
		if len(parts) < 2 {
			return
		}
		switch parts[1] {
		case "name":
			skill.Tokens.Name = result.Count
		case "description":
			skill.Tokens.Description = result.Count
		case "body":
			skill.Tokens.Body = result.Count
		case "ref":
			if len(parts) == 3 {
				skill.Tokens.References[parts[2]] = result.Count
			}
		}
	}

	outstanding := 0
	for _, job := range jobs {
		counter.Count(job.ID, job.Text)
		outstanding++
		// Drain any available results to prevent backpressure
		for {
			if result, ok := counter.TryGetResult(); ok {
				processResult(result)
				outstanding--
			} else {
				break
			}
		}
	}

	// Drain remaining results
	for outstanding > 0 {
		result := counter.GetResult()
		processResult(result)
		outstanding--
	}

	// Calculate total
	skill.Tokens.Total = skill.Tokens.Name + skill.Tokens.Description + skill.Tokens.Body
	for _, count := range skill.Tokens.References {
		skill.Tokens.Total += count
	}

	return skill, nil
}

// parseFrontmatter splits a SKILL.md document into its frontmatter and body.
//
// The document must begin with a line that is exactly "---". The frontmatter
// runs until the next such line, and everything after that closing line is
// returned as the body. A leading UTF-8 byte-order mark is ignored, and
// "\r\n" line endings are normalized to "\n" first.
//
// Only the name and description keys are read, using a small subset of YAML
// rather than a full parser. One pair of surrounding quotes is removed from
// a value. A description may continue on following lines indented by at
// least two spaces; the pieces are joined with single spaces. A block-scalar
// header ("|" or ">", optionally followed by "-" or "+") introduces such
// continuation lines and is not itself part of the description.
func parseFrontmatter(content string) (Frontmatter, string, error) {
	content = strings.TrimPrefix(content, "\ufeff")
	content = strings.ReplaceAll(content, "\r\n", "\n")
	lines := strings.Split(content, "\n")
	if lines[0] != "---" {
		return Frontmatter{}, "", fmt.Errorf("missing frontmatter")
	}

	var fm Frontmatter
	endIdx := -1
	inDescription := false
	descIsBlock := false
	var descriptionLines []string

	for i := 1; i < len(lines); i++ {
		line := lines[i]
		if line == "---" {
			endIdx = i
			break
		}

		switch {
		case strings.HasPrefix(line, "name:"):
			fm.Name = unquoteYAMLScalar(strings.TrimSpace(strings.TrimPrefix(line, "name:")))
			inDescription = false
		case strings.HasPrefix(line, "description:"):
			value := strings.TrimSpace(strings.TrimPrefix(line, "description:"))
			// A repeated key replaces the earlier value instead of extending it.
			descriptionLines = nil
			descIsBlock = isBlockScalarHeader(value)
			if value != "" && !descIsBlock {
				descriptionLines = append(descriptionLines, value)
			}
			inDescription = true
		case inDescription && strings.HasPrefix(line, "  "):
			descriptionLines = append(descriptionLines, strings.TrimSpace(line))
		default:
			inDescription = false
		}
	}

	if endIdx < 0 {
		return Frontmatter{}, "", fmt.Errorf("unclosed frontmatter")
	}

	fm.Description = strings.Join(descriptionLines, " ")
	if !descIsBlock {
		// A quote may open on the first line and close on a continuation
		// line, so unquote the joined value rather than each piece.
		fm.Description = unquoteYAMLScalar(fm.Description)
	}

	body := strings.Join(lines[endIdx+1:], "\n")
	return fm, body, nil
}

// unquoteYAMLScalar strips one pair of matching surrounding quotes from s.
// For single-quoted values, YAML's '' escape is turned back into a single '.
// Escape sequences inside double quotes are left as written.
func unquoteYAMLScalar(s string) string {
	if len(s) < 2 || s[0] != s[len(s)-1] {
		return s
	}
	switch s[0] {
	case '"':
		return s[1 : len(s)-1]
	case '\'':
		return strings.ReplaceAll(s[1:len(s)-1], "''", "'")
	}
	return s
}

// isBlockScalarHeader reports whether s is a YAML block-scalar indicator,
// such as "|" or ">-", that introduces an indented multi-line value.
func isBlockScalarHeader(s string) bool {
	switch s {
	case "|", "|-", "|+", ">", ">-", ">+":
		return true
	}
	return false
}

func validateSkill(skill SkillInfo) []string {
	var errors []string

	// Validate name
	if len(skill.Frontmatter.Name) < 1 || len(skill.Frontmatter.Name) > 64 {
		errors = append(errors, "name must be 1-64 characters")
	}

	namePattern := regexp.MustCompile(`^[a-z0-9]+(-[a-z0-9]+)*$`)
	if !namePattern.MatchString(skill.Frontmatter.Name) {
		errors = append(errors, "name must be lowercase letters, numbers, hyphens only; no leading/trailing/consecutive hyphens")
	}

	if skill.Frontmatter.Name != skill.Dir {
		errors = append(errors, fmt.Sprintf("name '%s' doesn't match directory '%s'", skill.Frontmatter.Name, skill.Dir))
	}

	// Validate description
	if len(skill.Frontmatter.Description) < 1 {
		errors = append(errors, "description is empty")
	} else if len(skill.Frontmatter.Description) > 1024 {
		errors = append(errors, fmt.Sprintf("description is %d characters (max 1024)", len(skill.Frontmatter.Description)))
	}

	// Check body line count
	if skill.BodyLines > 500 {
		errors = append(errors, fmt.Sprintf("body has %d lines (recommended: < 500)", skill.BodyLines))
	}

	return errors
}

// countTokensAPI asks the syntheticAPI count_tokens endpoint how many input
// tokens text occupies when sent as a single user message to model, and
// returns the reported input_tokens value.
//
// Transport failures, non-2xx responses and undecodable bodies are returned
// as errors. For a non-2xx status, the error includes up to 1 KiB of the
// response body to aid debugging.
func countTokensAPI(apiKey string, text string) (int, error) {
	reqBody := map[string]interface{}{
		"model": model,
		"messages": []map[string]string{
			{
				"role":    "user",
				"content": text,
			},
		},
	}

	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return 0, fmt.Errorf("marshal request: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, syntheticAPI, bytes.NewBuffer(jsonData))
	if err != nil {
		return 0, fmt.Errorf("create request: %w", err)
	}

	req.Header.Set("Authorization", "Bearer "+apiKey)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("anthropic-version", "2023-06-01")

	resp, err := httpClient.Do(req)
	if err != nil {
		return 0, fmt.Errorf("HTTP request: %w", err)
	}
	defer func() {
		// Read whatever is left before closing. A body that is not read to
		// EOF prevents the transport from reusing the keep-alive connection
		// for the next of the many requests the workers make. A failed drain
		// only costs that reuse, so its error is deliberately ignored.
		_, _ = io.Copy(io.Discard, resp.Body)
		resp.Body.Close()
	}()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
		return 0, fmt.Errorf("API status %d: %s", resp.StatusCode, string(body))
	}

	var result struct {
		InputTokens int `json:"input_tokens"`
	}

	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return 0, fmt.Errorf("decode response: %w", err)
	}

	return result.InputTokens, nil
}

// PrevTokens holds the token counts of a skill's HEAD version. Metadata is
// name + description, Body is the SKILL.md body, and Total additionally
// includes every reference file.
type PrevTokens struct {
	Total, Metadata, Body int
}

// buildComparisons compares each current skill with its HEAD version, as
// loaded by getSkillTokensFromGit, and returns the differences keyed by
// skill directory.
//
// A skill whose HEAD version cannot be loaded is reported as new (IsNew,
// Percent 100), but only if its current token total is positive. Skills
// whose total, metadata and body counts are all unchanged are left out.
func buildComparisons(currentSkills []SkillInfo, counter *TokenCounter) map[string]SkillComparison {
	out := make(map[string]SkillComparison)

	for _, cur := range currentSkills {
		prev, err := getSkillTokensFromGit(cur.Dir, counter)
		curMetadata := cur.Tokens.Name + cur.Tokens.Description

		if err != nil {
			if cur.Tokens.Total > 0 {
				out[cur.Dir] = SkillComparison{
					Delta:         cur.Tokens.Total,
					MetadataDelta: curMetadata,
					BodyDelta:     cur.Tokens.Body,
					Percent:       100.0,
					IsNew:         true,
				}
			}
			continue
		}

		c := SkillComparison{
			PrevTotal:     prev.Total,
			PrevMetadata:  prev.Metadata,
			PrevBody:      prev.Body,
			Delta:         cur.Tokens.Total - prev.Total,
			MetadataDelta: curMetadata - prev.Metadata,
			BodyDelta:     cur.Tokens.Body - prev.Body,
		}
		if prev.Total > 0 {
			c.Percent = float64(c.Delta) / float64(prev.Total) * 100
		}

		if c.Delta != 0 || c.MetadataDelta != 0 || c.BodyDelta != 0 {
			out[cur.Dir] = c
		}
	}

	return out
}

// getSkillTokensFromGit counts the tokens of skill skillDir as committed at
// HEAD, using the same breakdown as analyzeSkill. Name + description count as
// metadata and the SKILL.md body as body. Every regular file directly inside
// skills/<skillDir>/references adds to the total only.
//
// It returns an error when HEAD's SKILL.md cannot be read (for example, the
// skill is new) or its frontmatter cannot be parsed. Failures listing or
// reading reference files, and individual token-count failures, are printed
// as warnings on stderr; those parts are left out of the totals.
func getSkillTokensFromGit(skillDir string, counter *TokenCounter) (PrevTokens, error) {
	skillPath := fmt.Sprintf("skills/%s/SKILL.md", skillDir)
	output, err := exec.Command("git", "show", fmt.Sprintf("HEAD:%s", skillPath)).Output()
	if err != nil {
		return PrevTokens{}, err
	}

	fm, body, err := parseFrontmatter(string(output))
	if err != nil {
		return PrevTokens{}, err
	}

	// Every job ID is idPrefix followed by "name", "description", "body" or
	// "ref".
	idPrefix := "prev:" + skillDir + ":"
	jobs := []TokenJob{
		{ID: idPrefix + "name", Text: fm.Name},
		{ID: idPrefix + "description", Text: fm.Description},
		{ID: idPrefix + "body", Text: body},
	}

	// The trailing slash makes ls-tree list the directory's immediate
	// children rather than the directory entry itself. Keeping only blobs
	// mirrors analyzeSkill, which skips sub-directories of references/. -z
	// yields NUL-terminated records with unquoted paths.
	refsPath := fmt.Sprintf("skills/%s/references", skillDir)
	output, err = exec.Command("git", "ls-tree", "-z", "HEAD", refsPath+"/").Output()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Warning: cannot list %s in HEAD: %v\n", refsPath, err)
	} else {
		for _, record := range strings.Split(string(output), "\x00") {
			// Each record is "<mode> <type> <object>\t<path>".
			tab := strings.IndexByte(record, '\t')
			if tab < 0 {
				continue
			}
			meta := strings.Fields(record[:tab])
			if len(meta) < 2 || meta[1] != "blob" {
				continue
			}
			refPath := record[tab+1:]
			refContent, err := exec.Command("git", "show", "HEAD:"+refPath).Output()
			if err != nil {
				fmt.Fprintf(os.Stderr, "Warning: cannot read %s from HEAD: %v\n", refPath, err)
				continue
			}
			jobs = append(jobs, TokenJob{ID: idPrefix + "ref", Text: string(refContent)})
		}
	}

	var prev PrevTokens
	processResult := func(result TokenResult) {
		if result.Err != nil {
			fmt.Fprintf(os.Stderr, "Warning: token count failed for %s: %v\n", result.ID, result.Err)
			return
		}
		switch strings.TrimPrefix(result.ID, idPrefix) {
		case "name", "description":
			prev.Metadata += result.Count
		case "body":
			prev.Body += result.Count
		}
		// Every part, references included, contributes to the total.
		prev.Total += result.Count
	}

	// Submit from a separate goroutine while this one collects results.
	// Sending and polling on a single goroutine can deadlock once the job and
	// result buffers are both full. Exactly one result arrives per job.
	go func() {
		for _, job := range jobs {
			counter.Count(job.ID, job.Text)
		}
	}()
	for range jobs {
		processResult(counter.GetResult())
	}

	return prev, nil
}

func printSkillReport(skill SkillInfo, comp *SkillComparison) {
	fmt.Printf("\n=== %s ===\n", skill.Dir)

	if len(skill.Errors) > 0 {
		fmt.Println("\nValidation errors:")
		for _, err := range skill.Errors {
			fmt.Printf("  ✗ %s\n", err)
		}
	}

	fmt.Println("\nToken breakdown:")
	fmt.Printf("  Name:        %5d tokens\n", skill.Tokens.Name)
	fmt.Printf("  Description: %5d tokens\n", skill.Tokens.Description)
	fmt.Printf("  Body:        %5d tokens (%d lines)\n", skill.Tokens.Body, skill.BodyLines)

	if len(skill.Tokens.References) > 0 {
		fmt.Println("  References:")
		// Sort reference names for consistent output
		refNames := make([]string, 0, len(skill.Tokens.References))
		for name := range skill.Tokens.References {
			refNames = append(refNames, name)
		}
		sort.Strings(refNames)

		for _, name := range refNames {
			count := skill.Tokens.References[name]
			fmt.Printf("    %-40s %5d tokens\n", name, count)
		}
	}

	fmt.Println("  ───────────────────────────────────────────────")

	// Print total with comparison if available
	if comp != nil {
		sign := "+"
		if comp.Delta < 0 {
			sign = ""
		}
		indicator := ""
		if comp.IsNew {
			indicator = " [NEW]"
		} else if comp.Percent > 20 {
			indicator = " ⚠️"
		} else if comp.Percent < -20 {
			indicator = " ✓"
		}
		fmt.Printf("  Total:       %5d tokens (%s%d, %s%.1f%% from HEAD)%s\n",
			skill.Tokens.Total, sign, comp.Delta, sign, comp.Percent, indicator)
	} else {
		fmt.Printf("  Total:       %5d tokens\n", skill.Tokens.Total)
	}

	// Warn if approaching budget
	if skill.Tokens.Body > 5000 {
		fmt.Println("  ⚠️  Body exceeds recommended 5000 token budget!")
	} else if skill.Tokens.Body > 4000 {
		fmt.Println("  ⚠️  Body approaching 5000 token budget")
	}
}

// printSummary writes aggregate statistics for all skills to stdout:
//   - metadata (name + description), body and overall token totals;
//   - the number of validation errors;
//   - the five largest skills;
//   - when comparisons is non-empty, the five largest changes relative to
//     HEAD.
//
// Deltas are shown only when comparisons is non-nil and the delta is
// non-zero. skills is not modified; rankings are computed on a copy, with
// ties broken by name so the output is reproducible.
func printSummary(skills []SkillInfo, comparisons map[string]SkillComparison) {
	fmt.Println("\n" + strings.Repeat("=", 60))
	fmt.Println("SUMMARY")
	fmt.Println(strings.Repeat("=", 60))

	totalTokens := 0
	totalMetadataTokens := 0
	totalBodyTokens := 0
	totalErrors := 0
	totalDelta := 0
	metadataDelta := 0
	bodyDelta := 0

	for _, skill := range skills {
		totalTokens += skill.Tokens.Total
		totalMetadataTokens += skill.Tokens.Name + skill.Tokens.Description
		totalBodyTokens += skill.Tokens.Body
		totalErrors += len(skill.Errors)
		if comp, ok := comparisons[skill.Dir]; ok {
			totalDelta += comp.Delta
			metadataDelta += comp.MetadataDelta
			bodyDelta += comp.BodyDelta
		}
	}

	fmt.Printf("\nSkills: %d\n", len(skills))
	if comparisons != nil && metadataDelta != 0 {
		fmt.Printf("Metadata: %d tokens (%+d)\n", totalMetadataTokens, metadataDelta)
	} else {
		fmt.Printf("Metadata: %d tokens\n", totalMetadataTokens)
	}
	if comparisons != nil && bodyDelta != 0 {
		fmt.Printf("Combined bodies: %d tokens (%+d)\n", totalBodyTokens, bodyDelta)
	} else {
		fmt.Printf("Combined bodies: %d tokens\n", totalBodyTokens)
	}
	if comparisons != nil && totalDelta != 0 {
		fmt.Printf("Overall: %d tokens (%+d from HEAD)\n", totalTokens, totalDelta)
	} else {
		fmt.Printf("Overall: %d tokens\n", totalTokens)
	}
	fmt.Printf("Validation errors: %d\n", totalErrors)

	// Rank a copy: sorting skills in place would silently reorder the
	// caller's slice. sort.Slice is not stable, so equal totals are ordered
	// by directory name to keep the listing deterministic.
	ranked := make([]SkillInfo, len(skills))
	copy(ranked, skills)
	sort.Slice(ranked, func(i, j int) bool {
		if ranked[i].Tokens.Total != ranked[j].Tokens.Total {
			return ranked[i].Tokens.Total > ranked[j].Tokens.Total
		}
		return ranked[i].Dir < ranked[j].Dir
	})

	fmt.Println("\nLargest skills (by total tokens):")
	for i := 0; i < 5 && i < len(ranked); i++ {
		skill := ranked[i]
		if comp, ok := comparisons[skill.Dir]; ok {
			sign := "+"
			if comp.Delta < 0 {
				sign = ""
			}
			fmt.Printf("  %d. %-40s %5d tokens (%s%d)\n",
				i+1, skill.Dir, skill.Tokens.Total, sign, comp.Delta)
		} else {
			fmt.Printf("  %d. %-40s %5d tokens\n", i+1, skill.Dir, skill.Tokens.Total)
		}
	}

	// len of a nil map is 0, so this also covers a run without -compare.
	if len(comparisons) == 0 {
		return
	}

	type changeEntry struct {
		name string
		comp SkillComparison
	}
	changes := make([]changeEntry, 0, len(comparisons))
	for name, comp := range comparisons {
		changes = append(changes, changeEntry{name, comp})
	}

	absInt := func(x int) int {
		if x < 0 {
			return -x
		}
		return x
	}
	// Map iteration order is random, so break ties by name to keep the
	// listing reproducible.
	sort.Slice(changes, func(i, j int) bool {
		ai, aj := absInt(changes[i].comp.Delta), absInt(changes[j].comp.Delta)
		if ai != aj {
			return ai > aj
		}
		return changes[i].name < changes[j].name
	})
	if len(changes) > 5 {
		changes = changes[:5]
	}

	fmt.Println("\nBiggest changes:")
	for _, change := range changes {
		sign := "+"
		if change.comp.Delta < 0 {
			sign = ""
		}
		indicator := ""
		if change.comp.IsNew {
			indicator = " [NEW]"
		} else if change.comp.Percent > 20 {
			indicator = " ⚠️"
		} else if change.comp.Percent < -20 {
			indicator = " ✓"
		}
		fmt.Printf("  %-40s %s%-5d tokens (%s%.1f%%)%s\n",
			change.name, sign, change.comp.Delta, sign, change.comp.Percent, indicator)
	}
}
