stats.go

  1package cmd
  2
  3import (
  4	"bytes"
  5	"context"
  6	"database/sql"
  7	_ "embed"
  8	"encoding/json"
  9	"fmt"
 10	"html/template"
 11	"os"
 12	"os/user"
 13	"path/filepath"
 14	"strings"
 15	"time"
 16
 17	"github.com/charmbracelet/crush/internal/config"
 18	"github.com/charmbracelet/crush/internal/db"
 19	"github.com/pkg/browser"
 20	"github.com/spf13/cobra"
 21)
 22
// statsTemplate is the embedded stats/index.html file. generateHTML parses it
// as an html/template and executes it to produce the stats page.
//
//go:embed stats/index.html
var statsTemplate string

// statsCSS is the embedded stats/index.css file. generateHTML passes it to
// the template as trusted CSS (template.CSS).
//
//go:embed stats/index.css
var statsCSS string

// statsJS is the embedded stats/index.js file. generateHTML passes it to the
// template as trusted JavaScript (template.JS).
//
//go:embed stats/index.js
var statsJS string

// headerSVG is the embedded stats/header.svg file. generateHTML passes it to
// the template as trusted HTML in the Header field.
//
//go:embed stats/header.svg
var headerSVG string

// heartbitSVG is the embedded stats/heartbit.svg file. generateHTML passes it
// to the template as trusted HTML in the Heartbit field.
//
//go:embed stats/heartbit.svg
var heartbitSVG string

// footerSVG is the embedded stats/footer.svg file. generateHTML passes it to
// the template as trusted HTML in the Footer field.
//
//go:embed stats/footer.svg
var footerSVG string
 40
// statsCmd is the "stats" subcommand. Its RunE handler, runStats, generates
// the usage statistics page and tries to open it in a browser.
var statsCmd = &cobra.Command{
	Use:   "stats",
	Short: "Show usage statistics",
	Long:  "Generate and display usage statistics including token usage, costs, and activity patterns",
	RunE:  runStats,
}
 47
// dayNames maps a day-of-week index to its English name, with Sunday at
// index 0 and Saturday at index 6. gatherStats indexes it with the DayOfWeek
// value of each day-of-week usage row.
var dayNames = []string{"Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"}
 50
// Stats is the complete statistics payload assembled by gatherStats. It is
// marshaled to JSON by generateHTML and handed to the stats page template.
//
// GeneratedAt is the time at which gatherStats started. AvgResponseTimeMs is
// the value returned by the average-response-time query multiplied by 1000
// (presumably seconds converted to milliseconds; confirm against the query).
// Each slice field is filled from the rows of its corresponding query.
type Stats struct {
	GeneratedAt       time.Time          `json:"generated_at"`
	Total             TotalStats         `json:"total"`
	UsageByDay        []DailyUsage       `json:"usage_by_day"`
	UsageByModel      []ModelUsage       `json:"usage_by_model"`
	UsageByHour       []HourlyUsage      `json:"usage_by_hour"`
	UsageByDayOfWeek  []DayOfWeekUsage   `json:"usage_by_day_of_week"`
	RecentActivity    []DailyActivity    `json:"recent_activity"`
	AvgResponseTimeMs float64            `json:"avg_response_time_ms"`
	ToolUsage         []ToolUsage        `json:"tool_usage"`
	HourDayHeatmap    []HourDayHeatmapPt `json:"hour_day_heatmap"`
}
 64
// TotalStats holds the overall aggregate figures, filled by gatherStats from
// the result of the GetTotalStats query.
//
// TotalTokens is computed in Go as TotalPromptTokens + TotalCompletionTokens;
// the remaining fields are taken directly from the query result. runStats
// treats TotalSessions == 0 as "no data available".
type TotalStats struct {
	TotalSessions         int64   `json:"total_sessions"`
	TotalPromptTokens     int64   `json:"total_prompt_tokens"`
	TotalCompletionTokens int64   `json:"total_completion_tokens"`
	TotalTokens           int64   `json:"total_tokens"`
	TotalCost             float64 `json:"total_cost"`
	TotalMessages         int64   `json:"total_messages"`
	AvgTokensPerSession   float64 `json:"avg_tokens_per_session"`
	AvgMessagesPerSession float64 `json:"avg_messages_per_session"`
}
 75
// DailyUsage is one row of the usage-by-day breakdown.
//
// Day is the row's day value formatted with %v (its exact format is decided
// by the query, which is not visible here). TotalTokens is computed in Go as
// PromptTokens + CompletionTokens; NULL token sums are stored as 0.
type DailyUsage struct {
	Day              string  `json:"day"`
	PromptTokens     int64   `json:"prompt_tokens"`
	CompletionTokens int64   `json:"completion_tokens"`
	TotalTokens      int64   `json:"total_tokens"`
	Cost             float64 `json:"cost"`
	SessionCount     int64   `json:"session_count"`
}
 84
// ModelUsage is one row of the usage-by-model breakdown: the number of
// messages recorded for a given model and provider pair.
type ModelUsage struct {
	Model        string `json:"model"`
	Provider     string `json:"provider"`
	MessageCount int64  `json:"message_count"`
}
 90
// HourlyUsage is one row of the usage-by-hour breakdown: the session count
// for a given hour of the day.
//
// NOTE(review): Hour is presumably in the range 0-23; confirm against the
// GetUsageByHour query.
type HourlyUsage struct {
	Hour         int   `json:"hour"`
	SessionCount int64 `json:"session_count"`
}
 95
// DayOfWeekUsage is one row of the usage-by-day-of-week breakdown.
//
// DayOfWeek is the numeric index used to look up DayName in dayNames, so
// 0 means Sunday. NULL token sums are stored as 0.
type DayOfWeekUsage struct {
	DayOfWeek        int    `json:"day_of_week"`
	DayName          string `json:"day_name"`
	SessionCount     int64  `json:"session_count"`
	PromptTokens     int64  `json:"prompt_tokens"`
	CompletionTokens int64  `json:"completion_tokens"`
}
103
// DailyActivity is one row of the recent-activity series, filled from the
// GetRecentActivity query.
//
// Day is the row's day value formatted with %v. A NULL token total is stored
// as 0.
//
// NOTE(review): gatherStats describes the series as "last 30 days"; the
// actual window is defined by the query, which is not visible here.
type DailyActivity struct {
	Day          string  `json:"day"`
	SessionCount int64   `json:"session_count"`
	TotalTokens  int64   `json:"total_tokens"`
	Cost         float64 `json:"cost"`
}
110
// ToolUsage is the number of calls recorded for a single tool. gatherStats
// only emits entries whose tool name is a non-empty string.
type ToolUsage struct {
	ToolName  string `json:"tool_name"`
	CallCount int64  `json:"call_count"`
}
115
// HourDayHeatmapPt is a single cell of the hour/day heatmap: the session
// count for one (day of week, hour) combination.
type HourDayHeatmapPt struct {
	DayOfWeek    int   `json:"day_of_week"`
	Hour         int   `json:"hour"`
	SessionCount int64 `json:"session_count"`
}
121
122func runStats(cmd *cobra.Command, _ []string) error {
123	dataDir, _ := cmd.Flags().GetString("data-dir")
124	ctx := cmd.Context()
125
126	if dataDir == "" {
127		cfg, err := config.Init("", "", false)
128		if err != nil {
129			return fmt.Errorf("failed to initialize config: %w", err)
130		}
131		dataDir = cfg.Options.DataDirectory
132	}
133
134	conn, err := db.Connect(ctx, dataDir)
135	if err != nil {
136		return fmt.Errorf("failed to connect to database: %w", err)
137	}
138	defer conn.Close()
139
140	stats, err := gatherStats(ctx, conn)
141	if err != nil {
142		return fmt.Errorf("failed to gather stats: %w", err)
143	}
144
145	if stats.Total.TotalSessions == 0 {
146		return fmt.Errorf("no data available: no sessions found in database")
147	}
148
149	currentUser, err := user.Current()
150	if err != nil {
151		return fmt.Errorf("failed to get current user: %w", err)
152	}
153	username := currentUser.Username
154	project, err := os.Getwd()
155	if err != nil {
156		return fmt.Errorf("failed to get current directory: %w", err)
157	}
158	project = strings.Replace(project, currentUser.HomeDir, "~", 1)
159
160	htmlPath := filepath.Join(dataDir, "stats/index.html")
161	if err := generateHTML(stats, project, username, htmlPath); err != nil {
162		return fmt.Errorf("failed to generate HTML: %w", err)
163	}
164
165	fmt.Printf("Stats generated: %s\n", htmlPath)
166
167	if err := browser.OpenFile(htmlPath); err != nil {
168		fmt.Printf("Could not open browser: %v\n", err)
169		fmt.Println("Please open the file manually.")
170	}
171
172	return nil
173}
174
175func gatherStats(ctx context.Context, conn *sql.DB) (*Stats, error) {
176	queries := db.New(conn)
177
178	stats := &Stats{
179		GeneratedAt: time.Now(),
180	}
181
182	// Total stats.
183	total, err := queries.GetTotalStats(ctx)
184	if err != nil {
185		return nil, fmt.Errorf("get total stats: %w", err)
186	}
187	stats.Total = TotalStats{
188		TotalSessions:         total.TotalSessions,
189		TotalPromptTokens:     toInt64(total.TotalPromptTokens),
190		TotalCompletionTokens: toInt64(total.TotalCompletionTokens),
191		TotalTokens:           toInt64(total.TotalPromptTokens) + toInt64(total.TotalCompletionTokens),
192		TotalCost:             toFloat64(total.TotalCost),
193		TotalMessages:         toInt64(total.TotalMessages),
194		AvgTokensPerSession:   toFloat64(total.AvgTokensPerSession),
195		AvgMessagesPerSession: toFloat64(total.AvgMessagesPerSession),
196	}
197
198	// Usage by day.
199	dailyUsage, err := queries.GetUsageByDay(ctx)
200	if err != nil {
201		return nil, fmt.Errorf("get usage by day: %w", err)
202	}
203	for _, d := range dailyUsage {
204		prompt := nullFloat64ToInt64(d.PromptTokens)
205		completion := nullFloat64ToInt64(d.CompletionTokens)
206		stats.UsageByDay = append(stats.UsageByDay, DailyUsage{
207			Day:              fmt.Sprintf("%v", d.Day),
208			PromptTokens:     prompt,
209			CompletionTokens: completion,
210			TotalTokens:      prompt + completion,
211			Cost:             d.Cost.Float64,
212			SessionCount:     d.SessionCount,
213		})
214	}
215
216	// Usage by model.
217	modelUsage, err := queries.GetUsageByModel(ctx)
218	if err != nil {
219		return nil, fmt.Errorf("get usage by model: %w", err)
220	}
221	for _, m := range modelUsage {
222		stats.UsageByModel = append(stats.UsageByModel, ModelUsage{
223			Model:        m.Model,
224			Provider:     m.Provider,
225			MessageCount: m.MessageCount,
226		})
227	}
228
229	// Usage by hour.
230	hourlyUsage, err := queries.GetUsageByHour(ctx)
231	if err != nil {
232		return nil, fmt.Errorf("get usage by hour: %w", err)
233	}
234	for _, h := range hourlyUsage {
235		stats.UsageByHour = append(stats.UsageByHour, HourlyUsage{
236			Hour:         int(h.Hour),
237			SessionCount: h.SessionCount,
238		})
239	}
240
241	// Usage by day of week.
242	dowUsage, err := queries.GetUsageByDayOfWeek(ctx)
243	if err != nil {
244		return nil, fmt.Errorf("get usage by day of week: %w", err)
245	}
246	for _, d := range dowUsage {
247		stats.UsageByDayOfWeek = append(stats.UsageByDayOfWeek, DayOfWeekUsage{
248			DayOfWeek:        int(d.DayOfWeek),
249			DayName:          dayNames[int(d.DayOfWeek)],
250			SessionCount:     d.SessionCount,
251			PromptTokens:     nullFloat64ToInt64(d.PromptTokens),
252			CompletionTokens: nullFloat64ToInt64(d.CompletionTokens),
253		})
254	}
255
256	// Recent activity (last 30 days).
257	recent, err := queries.GetRecentActivity(ctx)
258	if err != nil {
259		return nil, fmt.Errorf("get recent activity: %w", err)
260	}
261	for _, r := range recent {
262		stats.RecentActivity = append(stats.RecentActivity, DailyActivity{
263			Day:          fmt.Sprintf("%v", r.Day),
264			SessionCount: r.SessionCount,
265			TotalTokens:  nullFloat64ToInt64(r.TotalTokens),
266			Cost:         r.Cost.Float64,
267		})
268	}
269
270	// Average response time.
271	avgResp, err := queries.GetAverageResponseTime(ctx)
272	if err != nil {
273		return nil, fmt.Errorf("get average response time: %w", err)
274	}
275	stats.AvgResponseTimeMs = toFloat64(avgResp) * 1000
276
277	// Tool usage.
278	toolUsage, err := queries.GetToolUsage(ctx)
279	if err != nil {
280		return nil, fmt.Errorf("get tool usage: %w", err)
281	}
282	for _, t := range toolUsage {
283		if name, ok := t.ToolName.(string); ok && name != "" {
284			stats.ToolUsage = append(stats.ToolUsage, ToolUsage{
285				ToolName:  name,
286				CallCount: t.CallCount,
287			})
288		}
289	}
290
291	// Hour/day heatmap.
292	heatmap, err := queries.GetHourDayHeatmap(ctx)
293	if err != nil {
294		return nil, fmt.Errorf("get hour day heatmap: %w", err)
295	}
296	for _, h := range heatmap {
297		stats.HourDayHeatmap = append(stats.HourDayHeatmap, HourDayHeatmapPt{
298			DayOfWeek:    int(h.DayOfWeek),
299			Hour:         int(h.Hour),
300			SessionCount: h.SessionCount,
301		})
302	}
303
304	return stats, nil
305}
306
// toInt64 converts a loosely typed numeric value to int64.
//
// int64 and int values are returned as-is (widened where needed) and float64
// values are truncated toward zero. Any other dynamic type, including nil,
// yields 0.
func toInt64(v any) int64 {
	if i, ok := v.(int64); ok {
		return i
	}
	if f, ok := v.(float64); ok {
		return int64(f)
	}
	if i, ok := v.(int); ok {
		return int64(i)
	}
	return 0
}
319
// toFloat64 converts a loosely typed numeric value to float64.
//
// float64 values are returned unchanged; int64 and int values are converted.
// Any other dynamic type, including nil, yields 0.
func toFloat64(v any) float64 {
	if f, ok := v.(float64); ok {
		return f
	}
	if i, ok := v.(int64); ok {
		return float64(i)
	}
	if i, ok := v.(int); ok {
		return float64(i)
	}
	return 0
}
332
// nullFloat64ToInt64 converts a nullable float to int64. A NULL (invalid)
// value yields 0; a valid value is truncated toward zero.
func nullFloat64ToInt64(n sql.NullFloat64) int64 {
	if !n.Valid {
		return 0
	}
	return int64(n.Float64)
}
339
340func generateHTML(stats *Stats, projName, username, path string) error {
341	statsJSON, err := json.Marshal(stats)
342	if err != nil {
343		return err
344	}
345
346	tmpl, err := template.New("stats").Parse(statsTemplate)
347	if err != nil {
348		return fmt.Errorf("parse template: %w", err)
349	}
350
351	data := struct {
352		StatsJSON   template.JS
353		CSS         template.CSS
354		JS          template.JS
355		Header      template.HTML
356		Heartbit    template.HTML
357		Footer      template.HTML
358		GeneratedAt string
359		ProjectName string
360		Username    string
361	}{
362		StatsJSON:   template.JS(statsJSON),
363		CSS:         template.CSS(statsCSS),
364		JS:          template.JS(statsJS),
365		Header:      template.HTML(headerSVG),
366		Heartbit:    template.HTML(heartbitSVG),
367		Footer:      template.HTML(footerSVG),
368		GeneratedAt: stats.GeneratedAt.Format("2006-01-02"),
369		ProjectName: projName,
370		Username:    username,
371	}
372
373	var buf bytes.Buffer
374	if err := tmpl.Execute(&buf, data); err != nil {
375		return fmt.Errorf("execute template: %w", err)
376	}
377
378	// Ensure parent directory exists.
379	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
380		return fmt.Errorf("create directory: %w", err)
381	}
382
383	return os.WriteFile(path, buf.Bytes(), 0o644)
384}