basic_transfer.go

  1package lfs
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"io"
 10	"net/http"
 11
 12	log "github.com/charmbracelet/log/v2"
 13)
 14
 15// BasicTransferAdapter implements the "basic" adapter.
 16type BasicTransferAdapter struct {
 17	client *http.Client
 18}
 19
 20// Name returns the name of the adapter.
 21func (a *BasicTransferAdapter) Name() string {
 22	return "basic"
 23}
 24
 25// Download reads the download location and downloads the data.
 26func (a *BasicTransferAdapter) Download(ctx context.Context, _ Pointer, l *Link) (io.ReadCloser, error) {
 27	resp, err := a.performRequest(ctx, "GET", l, nil, nil)
 28	if err != nil {
 29		return nil, err
 30	}
 31	return resp.Body, nil
 32}
 33
 34// Upload sends the content to the LFS server.
 35func (a *BasicTransferAdapter) Upload(ctx context.Context, p Pointer, r io.Reader, l *Link) error {
 36	res, err := a.performRequest(ctx, "PUT", l, r, func(req *http.Request) {
 37		if len(req.Header.Get("Content-Type")) == 0 {
 38			req.Header.Set("Content-Type", "application/octet-stream")
 39		}
 40
 41		if req.Header.Get("Transfer-Encoding") == "chunked" {
 42			req.TransferEncoding = []string{"chunked"}
 43		}
 44
 45		req.ContentLength = p.Size
 46	})
 47	if err != nil {
 48		return err
 49	}
 50	return res.Body.Close()
 51}
 52
 53// Verify calls the verify handler on the LFS server.
 54func (a *BasicTransferAdapter) Verify(ctx context.Context, p Pointer, l *Link) error {
 55	logger := log.FromContext(ctx).WithPrefix("lfs")
 56	b, err := json.Marshal(p)
 57	if err != nil {
 58		logger.Errorf("Error encoding json: %v", err)
 59		return err
 60	}
 61
 62	res, err := a.performRequest(ctx, "POST", l, bytes.NewReader(b), func(req *http.Request) {
 63		req.Header.Set("Content-Type", MediaType)
 64	})
 65	if err != nil {
 66		return err
 67	}
 68	return res.Body.Close()
 69}
 70
 71func (a *BasicTransferAdapter) performRequest(ctx context.Context, method string, l *Link, body io.Reader, callback func(*http.Request)) (*http.Response, error) {
 72	logger := log.FromContext(ctx).WithPrefix("lfs")
 73	logger.Debugf("Calling: %s %s", method, l.Href)
 74
 75	req, err := http.NewRequestWithContext(ctx, method, l.Href, body)
 76	if err != nil {
 77		logger.Errorf("Error creating request: %v", err)
 78		return nil, err
 79	}
 80	for key, value := range l.Header {
 81		req.Header.Set(key, value)
 82	}
 83	req.Header.Set("Accept", MediaType)
 84
 85	if callback != nil {
 86		callback(req)
 87	}
 88
 89	res, err := a.client.Do(req)
 90	if err != nil {
 91		select {
 92		case <-ctx.Done():
 93			return res, ctx.Err()
 94		default:
 95		}
 96		logger.Errorf("Error while processing request: %v", err)
 97		return res, err
 98	}
 99
100	if res.StatusCode != http.StatusOK {
101		return res, handleErrorResponse(res)
102	}
103
104	return res, nil
105}
106
107func handleErrorResponse(resp *http.Response) error {
108	defer resp.Body.Close()
109
110	er, err := decodeResponseError(resp.Body)
111	if err != nil {
112		return fmt.Errorf("Request failed with status %s", resp.Status)
113	}
114	return errors.New(er.Message)
115}
116
117func decodeResponseError(r io.Reader) (ErrorResponse, error) {
118	var er ErrorResponse
119	err := json.NewDecoder(r).Decode(&er)
120	if err != nil {
121		log.Error("Error decoding json: %v", err)
122	}
123	return er, err
124}