ssestream.go

  1// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
  2
  3package ssestream
  4
  5import (
  6	"bufio"
  7	"bytes"
  8	"encoding/json"
  9	"fmt"
 10	"io"
 11	"net/http"
 12	"strings"
 13
 14	"github.com/tidwall/gjson"
 15)
 16
 17type Decoder interface {
 18	Event() Event
 19	Next() bool
 20	Close() error
 21	Err() error
 22}
 23
 24func NewDecoder(res *http.Response) Decoder {
 25	if res == nil || res.Body == nil {
 26		return nil
 27	}
 28
 29	var decoder Decoder
 30	contentType := res.Header.Get("content-type")
 31	if t, ok := decoderTypes[contentType]; ok {
 32		decoder = t(res.Body)
 33	} else {
 34		scn := bufio.NewScanner(res.Body)
 35		scn.Buffer(nil, bufio.MaxScanTokenSize<<4)
 36		decoder = &eventStreamDecoder{rc: res.Body, scn: scn}
 37	}
 38	return decoder
 39}
 40
 41var decoderTypes = map[string](func(io.ReadCloser) Decoder){}
 42
 43func RegisterDecoder(contentType string, decoder func(io.ReadCloser) Decoder) {
 44	decoderTypes[strings.ToLower(contentType)] = decoder
 45}
 46
 47type Event struct {
 48	Type string
 49	Data []byte
 50}
 51
 52// A base implementation of a Decoder for text/event-stream.
 53type eventStreamDecoder struct {
 54	evt Event
 55	rc  io.ReadCloser
 56	scn *bufio.Scanner
 57	err error
 58}
 59
 60func (s *eventStreamDecoder) Next() bool {
 61	if s.err != nil {
 62		return false
 63	}
 64
 65	event := ""
 66	data := bytes.NewBuffer(nil)
 67
 68	for s.scn.Scan() {
 69		txt := s.scn.Bytes()
 70
 71		// Dispatch event on an empty line
 72		if len(txt) == 0 {
 73			s.evt = Event{
 74				Type: event,
 75				Data: data.Bytes(),
 76			}
 77			return true
 78		}
 79
 80		// Split a string like "event: bar" into name="event" and value=" bar".
 81		name, value, _ := bytes.Cut(txt, []byte(":"))
 82
 83		// Consume an optional space after the colon if it exists.
 84		if len(value) > 0 && value[0] == ' ' {
 85			value = value[1:]
 86		}
 87
 88		switch string(name) {
 89		case "":
 90			// An empty line in the for ": something" is a comment and should be ignored.
 91			continue
 92		case "event":
 93			event = string(value)
 94		case "data":
 95			_, s.err = data.Write(value)
 96			if s.err != nil {
 97				break
 98			}
 99			_, s.err = data.WriteRune('\n')
100			if s.err != nil {
101				break
102			}
103		}
104	}
105
106	if s.scn.Err() != nil {
107		s.err = s.scn.Err()
108	}
109
110	return false
111}
112
113func (s *eventStreamDecoder) Event() Event {
114	return s.evt
115}
116
117func (s *eventStreamDecoder) Close() error {
118	return s.rc.Close()
119}
120
121func (s *eventStreamDecoder) Err() error {
122	return s.err
123}
124
125type Stream[T any] struct {
126	decoder Decoder
127	cur     T
128	err     error
129	done    bool
130}
131
132func NewStream[T any](decoder Decoder, err error) *Stream[T] {
133	return &Stream[T]{
134		decoder: decoder,
135		err:     err,
136	}
137}
138
139// Next returns false if the stream has ended or an error occurred.
140// Call Stream.Current() to get the current value.
141// Call Stream.Err() to get the error.
142//
143//		for stream.Next() {
144//			data := stream.Current()
145//		}
146//
147//	 	if stream.Err() != nil {
148//			...
149//	 	}
150func (s *Stream[T]) Next() bool {
151	if s.err != nil {
152		return false
153	}
154
155	for s.decoder.Next() {
156		if s.done {
157			continue
158		}
159
160		if bytes.HasPrefix(s.decoder.Event().Data, []byte("[DONE]")) {
161			// In this case we don't break because we still want to iterate through the full stream.
162			s.done = true
163			continue
164		}
165
166		var nxt T
167		if s.decoder.Event().Type == "" || strings.HasPrefix(s.decoder.Event().Type, "response.") {
168			ep := gjson.GetBytes(s.decoder.Event().Data, "error")
169			if ep.Exists() {
170				s.err = fmt.Errorf("received error while streaming: %s", ep.String())
171				return false
172			}
173			s.err = json.Unmarshal(s.decoder.Event().Data, &nxt)
174			if s.err != nil {
175				return false
176			}
177			s.cur = nxt
178			return true
179		} else {
180			ep := gjson.GetBytes(s.decoder.Event().Data, "error")
181			if ep.Exists() {
182				s.err = fmt.Errorf("received error while streaming: %s", ep.String())
183				return false
184			}
185			event := s.decoder.Event().Type
186			data := s.decoder.Event().Data
187			s.err = json.Unmarshal([]byte(fmt.Sprintf(`{ "event": %q, "data": %s }`, event, data)), &nxt)
188			if s.err != nil {
189				return false
190			}
191			s.cur = nxt
192			return true
193		}
194	}
195
196	// decoder.Next() may be false because of an error
197	s.err = s.decoder.Err()
198
199	return false
200}
201
202func (s *Stream[T]) Current() T {
203	return s.cur
204}
205
206func (s *Stream[T]) Err() error {
207	return s.err
208}
209
210func (s *Stream[T]) Close() error {
211	if s.decoder == nil {
212		// already closed
213		return nil
214	}
215	return s.decoder.Close()
216}