basic_transfer.go

  1// Package lfs provides Git LFS (Large File Storage) functionality.
  2package lfs
  3
  4import (
  5	"bytes"
  6	"context"
  7	"encoding/json"
  8	"errors"
  9	"fmt"
 10	"io"
 11	"net/http"
 12
 13	"github.com/charmbracelet/log/v2"
 14)
 15
 16// BasicTransferAdapter implements the "basic" adapter.
 17type BasicTransferAdapter struct {
 18	client *http.Client
 19}
 20
 21// Name returns the name of the adapter.
 22func (a *BasicTransferAdapter) Name() string {
 23	return "basic"
 24}
 25
 26// Download reads the download location and downloads the data.
 27func (a *BasicTransferAdapter) Download(ctx context.Context, _ Pointer, l *Link) (io.ReadCloser, error) {
 28	resp, err := a.performRequest(ctx, "GET", l, nil, nil)
 29	if err != nil {
 30		return nil, err
 31	}
 32	return resp.Body, nil
 33}
 34
 35// Upload sends the content to the LFS server.
 36func (a *BasicTransferAdapter) Upload(ctx context.Context, p Pointer, r io.Reader, l *Link) error {
 37	res, err := a.performRequest(ctx, "PUT", l, r, func(req *http.Request) {
 38		if len(req.Header.Get("Content-Type")) == 0 {
 39			req.Header.Set("Content-Type", "application/octet-stream")
 40		}
 41
 42		if req.Header.Get("Transfer-Encoding") == "chunked" {
 43			req.TransferEncoding = []string{"chunked"}
 44		}
 45
 46		req.ContentLength = p.Size
 47	})
 48	if err != nil {
 49		return err
 50	}
 51	return res.Body.Close() //nolint:wrapcheck
 52}
 53
 54// Verify calls the verify handler on the LFS server.
 55func (a *BasicTransferAdapter) Verify(ctx context.Context, p Pointer, l *Link) error {
 56	logger := log.FromContext(ctx).WithPrefix("lfs")
 57	b, err := json.Marshal(p)
 58	if err != nil {
 59		logger.Errorf("Error encoding json: %v", err)
 60		return err //nolint:wrapcheck
 61	}
 62
 63	res, err := a.performRequest(ctx, "POST", l, bytes.NewReader(b), func(req *http.Request) {
 64		req.Header.Set("Content-Type", MediaType)
 65	})
 66	if err != nil {
 67		return err
 68	}
 69	return res.Body.Close() //nolint:wrapcheck
 70}
 71
 72func (a *BasicTransferAdapter) performRequest(ctx context.Context, method string, l *Link, body io.Reader, callback func(*http.Request)) (*http.Response, error) {
 73	logger := log.FromContext(ctx).WithPrefix("lfs")
 74	logger.Debugf("Calling: %s %s", method, l.Href)
 75
 76	req, err := http.NewRequestWithContext(ctx, method, l.Href, body)
 77	if err != nil {
 78		logger.Errorf("Error creating request: %v", err)
 79		return nil, err //nolint:wrapcheck
 80	}
 81	for key, value := range l.Header {
 82		req.Header.Set(key, value)
 83	}
 84	req.Header.Set("Accept", MediaType)
 85
 86	if callback != nil {
 87		callback(req)
 88	}
 89
 90	res, err := a.client.Do(req)
 91	if err != nil {
 92		select {
 93		case <-ctx.Done():
 94			return res, ctx.Err() //nolint:wrapcheck
 95		default:
 96		}
 97		logger.Errorf("Error while processing request: %v", err)
 98		return res, err //nolint:wrapcheck
 99	}
100
101	if res.StatusCode != http.StatusOK {
102		return res, handleErrorResponse(res)
103	}
104
105	return res, nil
106}
107
108func handleErrorResponse(resp *http.Response) error {
109	defer resp.Body.Close() //nolint: errcheck
110
111	er, err := decodeResponseError(resp.Body)
112	if err != nil {
113		return fmt.Errorf("request failed with status %s", resp.Status)
114	}
115	return errors.New(er.Message)
116}
117
118func decodeResponseError(r io.Reader) (ErrorResponse, error) {
119	var er ErrorResponse
120	err := json.NewDecoder(r).Decode(&er)
121	if err != nil {
122		log.Error("Error decoding json: %v", err)
123	}
124	return er, err //nolint:wrapcheck
125}