123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373 |
- package wsutil
- import (
- "encoding/binary"
- "errors"
- "io"
- "io/ioutil"
- "github.com/gobwas/ws"
- )
var (
	// ErrNoFrameAdvance is returned by Reader.Read() when it is called
	// without a preceding NextFrame() call.
	ErrNoFrameAdvance = errors.New("no frame advance")

	// ErrFrameTooLarge is returned when a frame whose payload length exceeds
	// Reader.MaxFrameSize is about to be read.
	ErrFrameTooLarge = errors.New("frame too large")
)
- // FrameHandlerFunc handles parsed frame header and its body represented by
- // io.Reader.
- //
- // Note that reader represents already unmasked body.
- type FrameHandlerFunc func(ws.Header, io.Reader) error
// Reader is a wrapper around a source io.Reader which represents a WebSocket
// connection. It contains options for reading messages from the source.
//
// Reader implements io.Reader: its Read() method reads the payload of
// incoming WebSocket frames. It also takes care of fragmented messages and of
// control frames that may arrive between their fragments.
//
// Note that Reader's methods are not goroutine safe.
type Reader struct {
	// Source is the underlying stream that frame headers and payloads are
	// read from.
	Source io.Reader

	// State is the connection state. It is passed to ws.CheckHeader (unless
	// SkipHeaderCheck is set), and NextFrame() sets or clears its fragmented
	// bit depending on each frame's fin flag.
	State ws.State

	// SkipHeaderCheck disables checking header bits to be RFC6455 compliant.
	SkipHeaderCheck bool

	// CheckUTF8 enables UTF-8 checks for text frames payload. If incoming
	// bytes are not a valid UTF-8 sequence, ErrInvalidUTF8 is returned once
	// the whole message has been read.
	CheckUTF8 bool

	// Extensions is a list of negotiated extensions for reader Source.
	// It is used to meet the specs and clear appropriate bits in fragment
	// header RSV segment.
	Extensions []RecvExtension

	// MaxFrameSize controls the maximum frame payload size in bytes that can
	// be read. NextFrame() returns ErrFrameTooLarge for any single frame whose
	// header declares a larger length; the limit is checked per frame.
	//
	// Not setting this field (or setting it to zero) means there is no limit.
	MaxFrameSize int64

	// OnContinuation, if non-nil, is called by NextFrame() for every
	// continuation frame, with the header and the (already unmasked) payload
	// reader. Its error is returned from NextFrame().
	OnContinuation FrameHandlerFunc

	// OnIntermediate, if non-nil, is called by NextFrame() for control frames
	// received between the fragments of a fragmented message. If it returns
	// nil, any payload it left unread is discarded; otherwise its error is
	// returned from NextFrame().
	OnIntermediate FrameHandlerFunc

	opCode ws.OpCode                  // Op code of the current (possibly fragmented) message.
	frame  io.Reader                  // Payload reader of the current frame (unmasked, optionally UTF-8 checked).
	raw    io.LimitedReader           // Raw payload reader, used to discard frames without unmasking.
	utf8   UTF8Reader                 // Used to check UTF8 sequences if CheckUTF8 is true.
	tmp    [ws.MaxHeaderSize - 2]byte // Scratch buffer for reading headers.
	cr     *CipherReader              // Reused by NextFrame() to unmask frame payload.
}
- // NewReader creates new frame reader that reads from r keeping given state to
- // make some protocol validity checks when it needed.
- func NewReader(r io.Reader, s ws.State) *Reader {
- return &Reader{
- Source: r,
- State: s,
- }
- }
- // NewClientSideReader is a helper function that calls NewReader with r and
- // ws.StateClientSide.
- func NewClientSideReader(r io.Reader) *Reader {
- return NewReader(r, ws.StateClientSide)
- }
- // NewServerSideReader is a helper function that calls NewReader with r and
- // ws.StateServerSide.
- func NewServerSideReader(r io.Reader) *Reader {
- return NewReader(r, ws.StateServerSide)
- }
- // Read implements io.Reader. It reads the next message payload into p.
- // It takes care on fragmented messages.
- //
- // The error is io.EOF only if all of message bytes were read.
- // If an io.EOF happens during reading some but not all the message bytes
- // Read() returns io.ErrUnexpectedEOF.
- //
- // The error is ErrNoFrameAdvance if no NextFrame() call was made before
- // reading next message bytes.
- func (r *Reader) Read(p []byte) (n int, err error) {
- if r.frame == nil {
- if !r.fragmented() {
- // Every new Read() must be preceded by NextFrame() call.
- return 0, ErrNoFrameAdvance
- }
- // Read next continuation or intermediate control frame.
- _, err := r.NextFrame()
- if err != nil {
- return 0, err
- }
- if r.frame == nil {
- // We handled intermediate control and now got nothing to read.
- return 0, nil
- }
- }
- n, err = r.frame.Read(p)
- if err != nil && err != io.EOF {
- return n, err
- }
- if err == nil && r.raw.N != 0 {
- return n, nil
- }
- // EOF condition (either err is io.EOF or r.raw.N is zero).
- switch {
- case r.raw.N != 0:
- err = io.ErrUnexpectedEOF
- case r.fragmented():
- err = nil
- r.resetFragment()
- case r.CheckUTF8 && !r.utf8.Valid():
- // NOTE: check utf8 only when full message received, since partial
- // reads may be invalid.
- n = r.utf8.Accepted()
- err = ErrInvalidUTF8
- default:
- r.reset()
- err = io.EOF
- }
- return n, err
- }
- // Discard discards current message unread bytes.
- // It discards all frames of fragmented message.
- func (r *Reader) Discard() (err error) {
- for {
- _, err = io.Copy(ioutil.Discard, &r.raw)
- if err != nil {
- break
- }
- if !r.fragmented() {
- break
- }
- if _, err = r.NextFrame(); err != nil {
- break
- }
- }
- r.reset()
- return err
- }
- // NextFrame prepares r to read next message. It returns received frame header
- // and non-nil error on failure.
- //
- // Note that next NextFrame() call must be done after receiving or discarding
- // all current message bytes.
- func (r *Reader) NextFrame() (hdr ws.Header, err error) {
- hdr, err = r.readHeader(r.Source)
- if err == io.EOF && r.fragmented() {
- // If we are in fragmented state EOF means that is was totally
- // unexpected.
- //
- // NOTE: This is necessary to prevent callers such that
- // ioutil.ReadAll to receive some amount of bytes without an error.
- // ReadAll() ignores an io.EOF error, thus caller may think that
- // whole message fetched, but actually only part of it.
- err = io.ErrUnexpectedEOF
- }
- if err == nil && !r.SkipHeaderCheck {
- err = ws.CheckHeader(hdr, r.State)
- }
- if err != nil {
- return hdr, err
- }
- if n := r.MaxFrameSize; n > 0 && hdr.Length > n {
- return hdr, ErrFrameTooLarge
- }
- // Save raw reader to use it on discarding frame without ciphering and
- // other streaming checks.
- r.raw = io.LimitedReader{
- R: r.Source,
- N: hdr.Length,
- }
- frame := io.Reader(&r.raw)
- if hdr.Masked {
- if r.cr == nil {
- r.cr = NewCipherReader(frame, hdr.Mask)
- } else {
- r.cr.Reset(frame, hdr.Mask)
- }
- frame = r.cr
- }
- for _, x := range r.Extensions {
- hdr, err = x.UnsetBits(hdr)
- if err != nil {
- return hdr, err
- }
- }
- if r.fragmented() {
- if hdr.OpCode.IsControl() {
- if cb := r.OnIntermediate; cb != nil {
- err = cb(hdr, frame)
- }
- if err == nil {
- // Ensure that src is empty.
- _, err = io.Copy(ioutil.Discard, &r.raw)
- }
- return hdr, err
- }
- } else {
- r.opCode = hdr.OpCode
- }
- if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) {
- r.utf8.Source = frame
- frame = &r.utf8
- }
- // Save reader with ciphering and other streaming checks.
- r.frame = frame
- if hdr.OpCode == ws.OpContinuation {
- if cb := r.OnContinuation; cb != nil {
- err = cb(hdr, frame)
- }
- }
- if hdr.Fin {
- r.State = r.State.Clear(ws.StateFragmented)
- } else {
- r.State = r.State.Set(ws.StateFragmented)
- }
- return hdr, err
- }
- func (r *Reader) fragmented() bool {
- return r.State.Fragmented()
- }
- func (r *Reader) resetFragment() {
- r.raw = io.LimitedReader{}
- r.frame = nil
- // Reset source of the UTF8Reader, but not the state.
- r.utf8.Source = nil
- }
- func (r *Reader) reset() {
- r.raw = io.LimitedReader{}
- r.frame = nil
- r.utf8 = UTF8Reader{}
- r.opCode = 0
- }
- // readHeader reads a frame header from in.
- func (r *Reader) readHeader(in io.Reader) (h ws.Header, err error) {
- // Make slice of bytes with capacity 12 that could hold any header.
- //
- // The maximum header size is 14, but due to the 2 hop reads,
- // after first hop that reads first 2 constant bytes, we could reuse 2 bytes.
- // So 14 - 2 = 12.
- bts := r.tmp[:2]
- // Prepare to hold first 2 bytes to choose size of next read.
- _, err = io.ReadFull(in, bts)
- if err != nil {
- return h, err
- }
- const bit0 = 0x80
- h.Fin = bts[0]&bit0 != 0
- h.Rsv = (bts[0] & 0x70) >> 4
- h.OpCode = ws.OpCode(bts[0] & 0x0f)
- var extra int
- if bts[1]&bit0 != 0 {
- h.Masked = true
- extra += 4
- }
- length := bts[1] & 0x7f
- switch {
- case length < 126:
- h.Length = int64(length)
- case length == 126:
- extra += 2
- case length == 127:
- extra += 8
- default:
- err = ws.ErrHeaderLengthUnexpected
- return h, err
- }
- if extra == 0 {
- return h, err
- }
- // Increase len of bts to extra bytes need to read.
- // Overwrite first 2 bytes that was read before.
- bts = bts[:extra]
- _, err = io.ReadFull(in, bts)
- if err != nil {
- return h, err
- }
- switch {
- case length == 126:
- h.Length = int64(binary.BigEndian.Uint16(bts[:2]))
- bts = bts[2:]
- case length == 127:
- if bts[0]&0x80 != 0 {
- err = ws.ErrHeaderLengthMSB
- return h, err
- }
- h.Length = int64(binary.BigEndian.Uint64(bts[:8]))
- bts = bts[8:]
- }
- if h.Masked {
- copy(h.Mask[:], bts)
- }
- return h, nil
- }
- // NextReader prepares next message read from r. It returns header that
- // describes the message and io.Reader to read message's payload. It returns
- // non-nil error when it is not possible to read message's initial frame.
- //
- // Note that next NextReader() on the same r should be done after reading all
- // bytes from previously returned io.Reader. For more performant way to discard
- // message use Reader and its Discard() method.
- //
- // Note that it will not handle any "intermediate" frames, that possibly could
- // be received between text/binary continuation frames. That is, if peer sent
- // text/binary frame with fin flag "false", then it could send ping frame, and
- // eventually remaining part of text/binary frame with fin "true" – with
- // NextReader() the ping frame will be dropped without any notice. To handle
- // this rare, but possible situation (and if you do not know exactly which
- // frames peer could send), you could use Reader with OnIntermediate field set.
- func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) {
- rd := &Reader{
- Source: r,
- State: s,
- }
- header, err := rd.NextFrame()
- if err != nil {
- return header, nil, err
- }
- return header, rd, nil
- }
|