stats.go

  1package cmd
  2
  3import (
  4	"bytes"
  5	"context"
  6	"database/sql"
  7	_ "embed"
  8	"encoding/json"
  9	"fmt"
 10	"html/template"
 11	"os"
 12	"os/user"
 13	"path/filepath"
 14	"strings"
 15	"time"
 16
 17	"github.com/charmbracelet/crush/internal/config"
 18	"github.com/charmbracelet/crush/internal/db"
 19	"github.com/charmbracelet/crush/internal/event"
 20	"github.com/pkg/browser"
 21	"github.com/spf13/cobra"
 22)
 23
 24//go:embed stats/index.html
 25var statsTemplate string
 26
 27//go:embed stats/index.css
 28var statsCSS string
 29
 30//go:embed stats/index.js
 31var statsJS string
 32
 33//go:embed stats/header.svg
 34var headerSVG string
 35
 36//go:embed stats/heartbit.svg
 37var heartbitSVG string
 38
 39//go:embed stats/footer.svg
 40var footerSVG string
 41
 42var statsCmd = &cobra.Command{
 43	Use:   "stats",
 44	Short: "Show usage statistics",
 45	Long:  "Generate and display usage statistics including token usage, costs, and activity patterns",
 46	RunE:  runStats,
 47}
 48
 49// Day names for day of week statistics.
 50var dayNames = []string{"Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"}
 51
 52// Stats holds all the statistics data.
 53type Stats struct {
 54	GeneratedAt       time.Time          `json:"generated_at"`
 55	Total             TotalStats         `json:"total"`
 56	UsageByDay        []DailyUsage       `json:"usage_by_day"`
 57	UsageByModel      []ModelUsage       `json:"usage_by_model"`
 58	UsageByHour       []HourlyUsage      `json:"usage_by_hour"`
 59	UsageByDayOfWeek  []DayOfWeekUsage   `json:"usage_by_day_of_week"`
 60	RecentActivity    []DailyActivity    `json:"recent_activity"`
 61	AvgResponseTimeMs float64            `json:"avg_response_time_ms"`
 62	ToolUsage         []ToolUsage        `json:"tool_usage"`
 63	HourDayHeatmap    []HourDayHeatmapPt `json:"hour_day_heatmap"`
 64}
 65
 66type TotalStats struct {
 67	TotalSessions         int64   `json:"total_sessions"`
 68	TotalPromptTokens     int64   `json:"total_prompt_tokens"`
 69	TotalCompletionTokens int64   `json:"total_completion_tokens"`
 70	TotalTokens           int64   `json:"total_tokens"`
 71	TotalCost             float64 `json:"total_cost"`
 72	TotalMessages         int64   `json:"total_messages"`
 73	AvgTokensPerSession   float64 `json:"avg_tokens_per_session"`
 74	AvgMessagesPerSession float64 `json:"avg_messages_per_session"`
 75}
 76
 77type DailyUsage struct {
 78	Day              string  `json:"day"`
 79	PromptTokens     int64   `json:"prompt_tokens"`
 80	CompletionTokens int64   `json:"completion_tokens"`
 81	TotalTokens      int64   `json:"total_tokens"`
 82	Cost             float64 `json:"cost"`
 83	SessionCount     int64   `json:"session_count"`
 84}
 85
 86type ModelUsage struct {
 87	Model        string `json:"model"`
 88	Provider     string `json:"provider"`
 89	MessageCount int64  `json:"message_count"`
 90}
 91
 92type HourlyUsage struct {
 93	Hour         int   `json:"hour"`
 94	SessionCount int64 `json:"session_count"`
 95}
 96
 97type DayOfWeekUsage struct {
 98	DayOfWeek        int    `json:"day_of_week"`
 99	DayName          string `json:"day_name"`
100	SessionCount     int64  `json:"session_count"`
101	PromptTokens     int64  `json:"prompt_tokens"`
102	CompletionTokens int64  `json:"completion_tokens"`
103}
104
105type DailyActivity struct {
106	Day          string  `json:"day"`
107	SessionCount int64   `json:"session_count"`
108	TotalTokens  int64   `json:"total_tokens"`
109	Cost         float64 `json:"cost"`
110}
111
112type ToolUsage struct {
113	ToolName  string `json:"tool_name"`
114	CallCount int64  `json:"call_count"`
115}
116
117type HourDayHeatmapPt struct {
118	DayOfWeek    int   `json:"day_of_week"`
119	Hour         int   `json:"hour"`
120	SessionCount int64 `json:"session_count"`
121}
122
123func runStats(cmd *cobra.Command, _ []string) error {
124	dataDir, _ := cmd.Flags().GetString("data-dir")
125	ctx := cmd.Context()
126
127	cfg, err := config.Init("", dataDir, false)
128	if err != nil {
129		return fmt.Errorf("failed to initialize config: %w", err)
130	}
131	if dataDir == "" {
132		dataDir = cfg.Config().Options.DataDirectory
133	}
134	if shouldEnableMetrics(cfg.Config()) {
135		event.Init()
136	}
137
138	event.StatsViewed()
139
140	conn, err := db.Connect(ctx, dataDir)
141	if err != nil {
142		return fmt.Errorf("failed to connect to database: %w", err)
143	}
144	defer conn.Close()
145
146	stats, err := gatherStats(ctx, conn)
147	if err != nil {
148		return fmt.Errorf("failed to gather stats: %w", err)
149	}
150
151	if stats.Total.TotalSessions == 0 {
152		return fmt.Errorf("no data available: no sessions found in database")
153	}
154
155	currentUser, err := user.Current()
156	if err != nil {
157		return fmt.Errorf("failed to get current user: %w", err)
158	}
159	username := currentUser.Username
160	project, err := os.Getwd()
161	if err != nil {
162		return fmt.Errorf("failed to get current directory: %w", err)
163	}
164	project = strings.Replace(project, currentUser.HomeDir, "~", 1)
165
166	htmlPath := filepath.Join(dataDir, "stats/index.html")
167	if err := generateHTML(stats, project, username, htmlPath); err != nil {
168		return fmt.Errorf("failed to generate HTML: %w", err)
169	}
170
171	fmt.Printf("Stats generated: %s\n", htmlPath)
172
173	if err := browser.OpenFile(htmlPath); err != nil {
174		fmt.Printf("Could not open browser: %v\n", err)
175		fmt.Println("Please open the file manually.")
176	}
177
178	return nil
179}
180
181func gatherStats(ctx context.Context, conn *sql.DB) (*Stats, error) {
182	queries := db.New(conn)
183
184	stats := &Stats{
185		GeneratedAt: time.Now(),
186	}
187
188	// Total stats.
189	total, err := queries.GetTotalStats(ctx)
190	if err != nil {
191		return nil, fmt.Errorf("get total stats: %w", err)
192	}
193	stats.Total = TotalStats{
194		TotalSessions:         total.TotalSessions,
195		TotalPromptTokens:     toInt64(total.TotalPromptTokens),
196		TotalCompletionTokens: toInt64(total.TotalCompletionTokens),
197		TotalTokens:           toInt64(total.TotalPromptTokens) + toInt64(total.TotalCompletionTokens),
198		TotalCost:             toFloat64(total.TotalCost),
199		TotalMessages:         toInt64(total.TotalMessages),
200		AvgTokensPerSession:   toFloat64(total.AvgTokensPerSession),
201		AvgMessagesPerSession: toFloat64(total.AvgMessagesPerSession),
202	}
203
204	// Usage by day.
205	dailyUsage, err := queries.GetUsageByDay(ctx)
206	if err != nil {
207		return nil, fmt.Errorf("get usage by day: %w", err)
208	}
209	for _, d := range dailyUsage {
210		prompt := nullFloat64ToInt64(d.PromptTokens)
211		completion := nullFloat64ToInt64(d.CompletionTokens)
212		stats.UsageByDay = append(stats.UsageByDay, DailyUsage{
213			Day:              fmt.Sprintf("%v", d.Day),
214			PromptTokens:     prompt,
215			CompletionTokens: completion,
216			TotalTokens:      prompt + completion,
217			Cost:             d.Cost.Float64,
218			SessionCount:     d.SessionCount,
219		})
220	}
221
222	// Usage by model.
223	modelUsage, err := queries.GetUsageByModel(ctx)
224	if err != nil {
225		return nil, fmt.Errorf("get usage by model: %w", err)
226	}
227	for _, m := range modelUsage {
228		stats.UsageByModel = append(stats.UsageByModel, ModelUsage{
229			Model:        m.Model,
230			Provider:     m.Provider,
231			MessageCount: m.MessageCount,
232		})
233	}
234
235	// Usage by hour.
236	hourlyUsage, err := queries.GetUsageByHour(ctx)
237	if err != nil {
238		return nil, fmt.Errorf("get usage by hour: %w", err)
239	}
240	for _, h := range hourlyUsage {
241		stats.UsageByHour = append(stats.UsageByHour, HourlyUsage{
242			Hour:         int(h.Hour),
243			SessionCount: h.SessionCount,
244		})
245	}
246
247	// Usage by day of week.
248	dowUsage, err := queries.GetUsageByDayOfWeek(ctx)
249	if err != nil {
250		return nil, fmt.Errorf("get usage by day of week: %w", err)
251	}
252	for _, d := range dowUsage {
253		stats.UsageByDayOfWeek = append(stats.UsageByDayOfWeek, DayOfWeekUsage{
254			DayOfWeek:        int(d.DayOfWeek),
255			DayName:          dayNames[int(d.DayOfWeek)],
256			SessionCount:     d.SessionCount,
257			PromptTokens:     nullFloat64ToInt64(d.PromptTokens),
258			CompletionTokens: nullFloat64ToInt64(d.CompletionTokens),
259		})
260	}
261
262	// Recent activity (last 30 days).
263	recent, err := queries.GetRecentActivity(ctx)
264	if err != nil {
265		return nil, fmt.Errorf("get recent activity: %w", err)
266	}
267	for _, r := range recent {
268		stats.RecentActivity = append(stats.RecentActivity, DailyActivity{
269			Day:          fmt.Sprintf("%v", r.Day),
270			SessionCount: r.SessionCount,
271			TotalTokens:  nullFloat64ToInt64(r.TotalTokens),
272			Cost:         r.Cost.Float64,
273		})
274	}
275
276	// Average response time.
277	avgResp, err := queries.GetAverageResponseTime(ctx)
278	if err != nil {
279		return nil, fmt.Errorf("get average response time: %w", err)
280	}
281	stats.AvgResponseTimeMs = toFloat64(avgResp) * 1000
282
283	// Tool usage.
284	toolUsage, err := queries.GetToolUsage(ctx)
285	if err != nil {
286		return nil, fmt.Errorf("get tool usage: %w", err)
287	}
288	for _, t := range toolUsage {
289		if name, ok := t.ToolName.(string); ok && name != "" {
290			stats.ToolUsage = append(stats.ToolUsage, ToolUsage{
291				ToolName:  name,
292				CallCount: t.CallCount,
293			})
294		}
295	}
296
297	// Hour/day heatmap.
298	heatmap, err := queries.GetHourDayHeatmap(ctx)
299	if err != nil {
300		return nil, fmt.Errorf("get hour day heatmap: %w", err)
301	}
302	for _, h := range heatmap {
303		stats.HourDayHeatmap = append(stats.HourDayHeatmap, HourDayHeatmapPt{
304			DayOfWeek:    int(h.DayOfWeek),
305			Hour:         int(h.Hour),
306			SessionCount: h.SessionCount,
307		})
308	}
309
310	return stats, nil
311}
312
// toInt64 converts a loosely typed numeric value to int64.
//
// int64 and int values are returned as-is, float64 values are truncated
// toward zero, and any other dynamic type (including nil) yields 0.
func toInt64(v any) int64 {
	if n, ok := v.(int64); ok {
		return n
	}
	if f, ok := v.(float64); ok {
		return int64(f)
	}
	if n, ok := v.(int); ok {
		return int64(n)
	}
	return 0
}
325
// toFloat64 converts a loosely typed numeric value to float64.
//
// float64 values are returned as-is, int64 and int values are converted, and
// any other dynamic type (including nil) yields 0.
func toFloat64(v any) float64 {
	if f, ok := v.(float64); ok {
		return f
	}
	if n, ok := v.(int64); ok {
		return float64(n)
	}
	if n, ok := v.(int); ok {
		return float64(n)
	}
	return 0
}
338
// nullFloat64ToInt64 returns n's value truncated toward zero, or 0 when n is
// not valid (SQL NULL).
func nullFloat64ToInt64(n sql.NullFloat64) int64 {
	if !n.Valid {
		return 0
	}
	return int64(n.Float64)
}
345
346func generateHTML(stats *Stats, projName, username, path string) error {
347	statsJSON, err := json.Marshal(stats)
348	if err != nil {
349		return err
350	}
351
352	tmpl, err := template.New("stats").Parse(statsTemplate)
353	if err != nil {
354		return fmt.Errorf("parse template: %w", err)
355	}
356
357	data := struct {
358		StatsJSON   template.JS
359		CSS         template.CSS
360		JS          template.JS
361		Header      template.HTML
362		Heartbit    template.HTML
363		Footer      template.HTML
364		GeneratedAt string
365		ProjectName string
366		Username    string
367	}{
368		StatsJSON:   template.JS(statsJSON),
369		CSS:         template.CSS(statsCSS),
370		JS:          template.JS(statsJS),
371		Header:      template.HTML(headerSVG),
372		Heartbit:    template.HTML(heartbitSVG),
373		Footer:      template.HTML(footerSVG),
374		GeneratedAt: stats.GeneratedAt.Format("2006-01-02"),
375		ProjectName: projName,
376		Username:    username,
377	}
378
379	var buf bytes.Buffer
380	if err := tmpl.Execute(&buf, data); err != nil {
381		return fmt.Errorf("execute template: %w", err)
382	}
383
384	// Ensure parent directory exists.
385	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
386		return fmt.Errorf("create directory: %w", err)
387	}
388
389	return os.WriteFile(path, buf.Bytes(), 0o644)
390}