From 7e04c6fd4e2a3b03c8dc2b850250c3327f944fbb Mon Sep 17 00:00:00 2001 From: Amolith Date: Thu, 18 Dec 2025 23:24:27 -0700 Subject: [PATCH] refactor(lunatask): add Ping, extract shared code - Move error types, APIError, and doRequest helper to client.go - Add Ping() method for token verification - Add ListTasksOptions with source_id filter support Assisted-by: Claude Sonnet 4 via Crush --- lunatask/client.go | 102 ++++++++++++++++++++++++++++++++++++++++++ lunatask/tasks.go | 107 ++++++++------------------------------------- 2 files changed, 120 insertions(+), 89 deletions(-) diff --git a/lunatask/client.go b/lunatask/client.go index d643891b38dd84537d072172d163a68b5fa07f1d..1d045df7f9255e1798daf91276e61e98ef933c98 100644 --- a/lunatask/client.go +++ b/lunatask/client.go @@ -5,9 +5,77 @@ package lunatask import ( + "context" + "errors" + "fmt" + "io" "net/http" ) +// API error types for typed error handling +var ( + // ErrBadRequest indicates invalid, malformed, or missing parameters (400) + ErrBadRequest = errors.New("bad request") + // ErrUnauthorized indicates missing, wrong, or revoked access token (401) + ErrUnauthorized = errors.New("unauthorized") + // ErrPaymentRequired indicates a subscription is required (402) + ErrPaymentRequired = errors.New("subscription required") + // ErrNotFound indicates the specified entity could not be found (404) + ErrNotFound = errors.New("not found") + // ErrUnprocessableEntity indicates the provided entity is not valid (422) + ErrUnprocessableEntity = errors.New("unprocessable entity") + // ErrServerError indicates an internal server error (500) + ErrServerError = errors.New("server error") + // ErrServiceUnavailable indicates temporary maintenance (503) + ErrServiceUnavailable = errors.New("service unavailable") + // ErrTimeout indicates request timed out (524) + ErrTimeout = errors.New("request timed out") +) + +// APIError wraps an API error with status code and response body +type APIError struct { + StatusCode int + Body string + Err error +} + +func (e *APIError) Error() string { + if e.Body != "" { + return fmt.Sprintf("%s (status %d): %s", e.Err.Error(), e.StatusCode, e.Body) + } + return fmt.Sprintf("%s (status %d)", e.Err.Error(), e.StatusCode) +} + +func (e *APIError) Unwrap() error { + return e.Err +} + +// newAPIError creates an APIError from an HTTP status code +func newAPIError(statusCode int, body string) *APIError { + var err error + switch statusCode { + case http.StatusBadRequest: + err = ErrBadRequest + case http.StatusUnauthorized: + err = ErrUnauthorized + case http.StatusPaymentRequired: + err = ErrPaymentRequired + case http.StatusNotFound: + err = ErrNotFound + case http.StatusUnprocessableEntity: + err = ErrUnprocessableEntity + case http.StatusInternalServerError: + err = ErrServerError + case http.StatusServiceUnavailable: + err = ErrServiceUnavailable + case 524: + err = ErrTimeout + default: + err = fmt.Errorf("unexpected status %d", statusCode) + } + return &APIError{StatusCode: statusCode, Body: body, Err: err} +} + // Client handles communication with the Lunatask API type Client struct { AccessToken string @@ -23,3 +91,37 @@ func NewClient(accessToken string) *Client { HTTPClient: &http.Client{}, } } + +// doRequest performs an HTTP request and handles common response processing +func (c *Client) doRequest(req *http.Request) ([]byte, error) { + req.Header.Set("Authorization", "bearer "+c.AccessToken) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send HTTP request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, newAPIError(resp.StatusCode, string(body)) + } + + return body, nil +} + +// Ping verifies the access token is valid by calling the /ping endpoint. +// Returns nil on success, or an error (typically ErrUnauthorized) on failure. +func (c *Client) Ping(ctx context.Context) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.BaseURL+"/ping", nil) + if err != nil { + return fmt.Errorf("failed to create HTTP request: %w", err) + } + + _, err = c.doRequest(req) + return err +} diff --git a/lunatask/tasks.go b/lunatask/tasks.go index 9b109fa287b35185609e7431415327168ab5e682..15bdcb8069026a78d45e36265c09495e9e4e23ab 100644 --- a/lunatask/tasks.go +++ b/lunatask/tasks.go @@ -10,75 +10,10 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "net/url" ) -// API error types for typed error handling -var ( - // ErrBadRequest indicates invalid, malformed, or missing parameters (400) - ErrBadRequest = errors.New("bad request") - // ErrUnauthorized indicates missing, wrong, or revoked access token (401) - ErrUnauthorized = errors.New("unauthorized") - // ErrPaymentRequired indicates a subscription is required (402) - ErrPaymentRequired = errors.New("subscription required") - // ErrNotFound indicates the specified entity could not be found (404) - ErrNotFound = errors.New("not found") - // ErrUnprocessableEntity indicates the provided entity is not valid (422) - ErrUnprocessableEntity = errors.New("unprocessable entity") - // ErrServerError indicates an internal server error (500) - ErrServerError = errors.New("server error") - // ErrServiceUnavailable indicates temporary maintenance (503) - ErrServiceUnavailable = errors.New("service unavailable") - // ErrTimeout indicates request timed out (524) - ErrTimeout = errors.New("request timed out") -) - -// APIError wraps an API error with status code and response body -type APIError struct { - StatusCode int - Body string - Err error -} - -func (e *APIError) Error() string { - if e.Body != "" { - return fmt.Sprintf("%s (status %d): %s", e.Err.Error(), e.StatusCode, e.Body) - } - return fmt.Sprintf("%s (status %d)", e.Err.Error(), e.StatusCode) -} - -func (e *APIError) Unwrap() error { - return e.Err -} - -// newAPIError creates an APIError from an HTTP status code -func newAPIError(statusCode int, body string) *APIError { - var err error - switch statusCode { - case http.StatusBadRequest: - err = ErrBadRequest - case http.StatusUnauthorized: - err = ErrUnauthorized - case http.StatusPaymentRequired: - err = ErrPaymentRequired - case http.StatusNotFound: - err = ErrNotFound - case http.StatusUnprocessableEntity: - err = ErrUnprocessableEntity - case http.StatusInternalServerError: - err = ErrServerError - case http.StatusServiceUnavailable: - err = ErrServiceUnavailable - case 524: - err = ErrTimeout - default: - err = fmt.Errorf("unexpected status %d", statusCode) - } - return &APIError{StatusCode: statusCode, Body: body, Err: err} -} - // Source represents a task source like GitHub or other integrations type Source struct { Source string `json:"source"` @@ -149,33 +84,27 @@ type TasksResponse struct { Tasks []Task `json:"tasks"` } -// doRequest performs an HTTP request and handles common response processing -func (c *Client) doRequest(req *http.Request) ([]byte, error) { - req.Header.Set("Authorization", "bearer "+c.AccessToken) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send HTTP request: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, newAPIError(resp.StatusCode, string(body)) - } - - return body, nil +// ListTasksOptions contains optional filters for listing tasks +type ListTasksOptions struct { + Source *string + SourceID *string } -// ListTasks retrieves all tasks, optionally filtered by source -func (c *Client) ListTasks(ctx context.Context, source *string) ([]Task, error) { +// ListTasks retrieves all tasks, optionally filtered by source and/or source_id +func (c *Client) ListTasks(ctx context.Context, opts *ListTasksOptions) ([]Task, error) { u := fmt.Sprintf("%s/tasks", c.BaseURL) - if source != nil && *source != "" { - u = fmt.Sprintf("%s?source=%s", u, url.QueryEscape(*source)) + + if opts != nil { + params := url.Values{} + if opts.Source != nil && *opts.Source != "" { + params.Set("source", *opts.Source) + } + if opts.SourceID != nil && *opts.SourceID != "" { + params.Set("source_id", *opts.SourceID) + } + if len(params) > 0 { + u = fmt.Sprintf("%s?%s", u, params.Encode()) + } } req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)