dialer.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573
  1. package ws
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "crypto/tls"
  7. "fmt"
  8. "io"
  9. "net"
  10. "net/http"
  11. "net/url"
  12. "strconv"
  13. "strings"
  14. "time"
  15. "github.com/gobwas/httphead"
  16. "github.com/gobwas/pool/pbufio"
  17. )
// Constants used by Dialer.
const (
	// DefaultClientReadBufferSize is the read buffer size used by
	// Dialer.Upgrade when Dialer.ReadBufferSize is zero.
	DefaultClientReadBufferSize = 4096
	// DefaultClientWriteBufferSize is the write buffer size used by
	// Dialer.Upgrade when Dialer.WriteBufferSize is zero.
	DefaultClientWriteBufferSize = 4096
)
// Handshake represents the result of a WebSocket handshake.
type Handshake struct {
	// Protocol is the subprotocol selected during handshake.
	Protocol string
	// Extensions is the list of negotiated extensions.
	Extensions []httphead.Option
}
  30. // Errors used by the websocket client.
  31. var (
  32. ErrHandshakeBadStatus = fmt.Errorf("unexpected http status")
  33. ErrHandshakeBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol)
  34. ErrHandshakeBadExtensions = fmt.Errorf("unexpected extensions in %q header", headerSecProtocol)
  35. )
// DefaultDialer is a Dialer that holds no options; it is the one used by the
// package-level Dial function.
var DefaultDialer Dialer
  38. // Dial is like Dialer{}.Dial().
  39. func Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, Handshake, error) {
  40. return DefaultDialer.Dial(ctx, urlstr)
  41. }
// Dialer contains options for establishing a websocket connection to a url.
type Dialer struct {
	// ReadBufferSize and WriteBufferSize are the I/O buffer sizes.
	// They are used to read and write http data while upgrading to WebSocket.
	// Allocated buffers are pooled to avoid extra allocations.
	//
	// If a size is zero then the default value is used.
	ReadBufferSize, WriteBufferSize int
	// Timeout is the maximum amount of time a Dial() will wait for a connect
	// and a handshake to complete.
	//
	// The default is no timeout.
	Timeout time.Duration
	// Protocols is the list of subprotocols that the client wants to speak,
	// ordered by preference.
	//
	// See https://tools.ietf.org/html/rfc6455#section-4.1
	Protocols []string
	// Extensions is the list of extensions that the client wants to speak.
	//
	// Note that if the server decides to use some of these extensions, Dial()
	// will return a Handshake struct containing a slice of items which are
	// shallow copies of the items from this list. That is, internals of
	// Extensions items are shared during Dial().
	//
	// See https://tools.ietf.org/html/rfc6455#section-4.1
	// See https://tools.ietf.org/html/rfc6455#section-9.1
	Extensions []httphead.Option
	// Header is an optional HandshakeHeader instance that could be used to
	// write additional headers to the handshake request.
	//
	// It is used instead of any key-value mappings to avoid allocations in
	// user land.
	Header HandshakeHeader
	// Host is an optional string that could be used to specify the host during
	// the HTTP upgrade request by setting the 'Host' header.
	//
	// Default value is an empty string, which results in setting the 'Host'
	// header equal to the URL hostname given to Dialer.Dial().
	Host string
	// OnStatusError is the callback that will be called after receiving a
	// response status other than "101 Switching Protocols". It receives an
	// io.Reader object representing the server response bytes. That is, it
	// gives the ability to parse the HTTP response somehow (probably with an
	// http.ReadResponse call) and make a decision about further logic.
	//
	// The arguments are only valid until the callback returns.
	OnStatusError func(status int, reason []byte, resp io.Reader)
	// OnHeader is the callback that will be called after successful parsing of
	// a header that is not used during the WebSocket handshake procedure. That
	// is, it will be called with non-websocket headers, which could be
	// relevant for application-level logic.
	//
	// The arguments are only valid until the callback returns.
	//
	// A non-nil returned error stops processing of the response and is
	// returned from Upgrade().
	OnHeader func(key, value []byte) (err error)
	// NetDial is the function that is used to get a plain tcp connection.
	// If it is not nil, then it is used instead of net.Dialer.
	NetDial func(ctx context.Context, network, addr string) (net.Conn, error)
	// TLSClient is the callback that will be called after a successful dial
	// with the received connection and its remote host name. If it is nil,
	// then the default tls.Client() will be used.
	// If it is not nil, then the TLSConfig field is ignored.
	TLSClient func(conn net.Conn, hostname string) net.Conn
	// TLSConfig is passed to tls.Client() to start TLS over the established
	// connection. If TLSClient is not nil, then it is ignored. If TLSConfig is
	// non-nil and its ServerName is empty, then for every Dial() it will be
	// cloned and the appropriate ServerName will be set.
	TLSConfig *tls.Config
	// WrapConn is the optional callback that will be called when the
	// connection is ready for i/o. That is, it will be called after a
	// successful dial and TLS initialization (for "wss" schemes). It may be
	// helpful for different user land purposes such as end to end encryption.
	//
	// Note that for debugging purposes of an http handshake (e.g. sent request
	// and received response), there is a wsutil.DebugDialer struct.
	WrapConn func(conn net.Conn) net.Conn
}
  121. // Dial connects to the url host and upgrades connection to WebSocket.
  122. //
  123. // If server has sent frames right after successful handshake then returned
  124. // buffer will be non-nil. In other cases buffer is always nil. For better
  125. // memory efficiency received non-nil bufio.Reader should be returned to the
  126. // inner pool with PutReader() function after use.
  127. //
  128. // Note that Dialer does not implement IDNA (RFC5895) logic as net/http does.
  129. // If you want to dial non-ascii host name, take care of its name serialization
  130. // avoiding bad request issues. For more info see net/http Request.Write()
  131. // implementation, especially cleanHost() function.
  132. func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs Handshake, err error) {
  133. u, err := url.ParseRequestURI(urlstr)
  134. if err != nil {
  135. return nil, nil, hs, err
  136. }
  137. // Prepare context to dial with. Initially it is the same as original, but
  138. // if d.Timeout is non-zero and points to time that is before ctx.Deadline,
  139. // we use more shorter context for dial.
  140. dialctx := ctx
  141. var deadline time.Time
  142. if t := d.Timeout; t != 0 {
  143. deadline = time.Now().Add(t)
  144. if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
  145. var cancel context.CancelFunc
  146. dialctx, cancel = context.WithDeadline(ctx, deadline)
  147. defer cancel()
  148. }
  149. }
  150. if conn, err = d.dial(dialctx, u); err != nil {
  151. return conn, nil, hs, err
  152. }
  153. defer func() {
  154. if err != nil {
  155. conn.Close()
  156. }
  157. }()
  158. if ctx == context.Background() {
  159. // No need to start I/O interrupter goroutine which is not zero-cost.
  160. conn.SetDeadline(deadline)
  161. defer conn.SetDeadline(noDeadline)
  162. } else {
  163. // Context could be canceled or its deadline could be exceeded.
  164. // Start the interrupter goroutine to handle context cancelation.
  165. done := setupContextDeadliner(ctx, conn)
  166. defer func() {
  167. // Map Upgrade() error to a possible context expiration error. That
  168. // is, even if Upgrade() err is nil, context could be already
  169. // expired and connection be "poisoned" by SetDeadline() call.
  170. // In that case we must not return ctx.Err() error.
  171. done(&err)
  172. }()
  173. }
  174. br, hs, err = d.Upgrade(conn, u)
  175. return conn, br, hs, err
  176. }
var (
	// netEmptyDialer is a net.Dialer without options, used in Dialer.dial() if
	// Dialer.NetDial is not provided.
	netEmptyDialer net.Dialer
	// tlsEmptyConfig is an empty tls.Config used as the default one. It is
	// handed out by pointer from tlsDefaultConfig().
	tlsEmptyConfig tls.Config
)
  184. func tlsDefaultConfig() *tls.Config {
  185. return &tlsEmptyConfig
  186. }
  187. func hostport(host, defaultPort string) (hostname, addr string) {
  188. var (
  189. colon = strings.LastIndexByte(host, ':')
  190. bracket = strings.IndexByte(host, ']')
  191. )
  192. if colon > bracket {
  193. return host[:colon], host
  194. }
  195. return host, host + defaultPort
  196. }
  197. func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) {
  198. dial := d.NetDial
  199. if dial == nil {
  200. dial = netEmptyDialer.DialContext
  201. }
  202. switch u.Scheme {
  203. case "ws":
  204. _, addr := hostport(u.Host, ":80")
  205. conn, err = dial(ctx, "tcp", addr)
  206. case "wss":
  207. hostname, addr := hostport(u.Host, ":443")
  208. conn, err = dial(ctx, "tcp", addr)
  209. if err != nil {
  210. return nil, err
  211. }
  212. tlsClient := d.TLSClient
  213. if tlsClient == nil {
  214. tlsClient = d.tlsClient
  215. }
  216. conn = tlsClient(conn, hostname)
  217. default:
  218. return nil, fmt.Errorf("unexpected websocket scheme: %q", u.Scheme)
  219. }
  220. if wrap := d.WrapConn; wrap != nil {
  221. conn = wrap(conn)
  222. }
  223. return conn, err
  224. }
  225. func (d Dialer) tlsClient(conn net.Conn, hostname string) net.Conn {
  226. config := d.TLSConfig
  227. if config == nil {
  228. config = tlsDefaultConfig()
  229. }
  230. if config.ServerName == "" {
  231. config = tlsCloneConfig(config)
  232. config.ServerName = hostname
  233. }
  234. // Do not make conn.Handshake() here because downstairs we will prepare
  235. // i/o on this conn with proper context's timeout handling.
  236. return tls.Client(conn, config)
  237. }
var (
	// These variables are set like in net/net.go.
	// noDeadline is just the zero value, named for readability.
	noDeadline = time.Time{}
	// aLongTimeAgo is a non-zero time, far in the past, used for immediate
	// cancelation of i/o on a connection.
	aLongTimeAgo = time.Unix(42, 0)
)
  246. // Upgrade writes an upgrade request to the given io.ReadWriter conn at given
  247. // url u and reads a response from it.
  248. //
  249. // It is a caller responsibility to manage I/O deadlines on conn.
  250. //
  251. // It returns handshake info and some bytes which could be written by the peer
  252. // right after response and be caught by us during buffered read.
  253. func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Handshake, err error) {
  254. // headerSeen constants helps to report whether or not some header was seen
  255. // during reading request bytes.
  256. const (
  257. headerSeenUpgrade = 1 << iota
  258. headerSeenConnection
  259. headerSeenSecAccept
  260. // headerSeenAll is the value that we expect to receive at the end of
  261. // headers read/parse loop.
  262. headerSeenAll = 0 |
  263. headerSeenUpgrade |
  264. headerSeenConnection |
  265. headerSeenSecAccept
  266. )
  267. br = pbufio.GetReader(conn,
  268. nonZero(d.ReadBufferSize, DefaultClientReadBufferSize),
  269. )
  270. bw := pbufio.GetWriter(conn,
  271. nonZero(d.WriteBufferSize, DefaultClientWriteBufferSize),
  272. )
  273. defer func() {
  274. pbufio.PutWriter(bw)
  275. if br.Buffered() == 0 || err != nil {
  276. // Server does not wrote additional bytes to the connection or
  277. // error occurred. That is, no reason to return buffer.
  278. pbufio.PutReader(br)
  279. br = nil
  280. }
  281. }()
  282. nonce := make([]byte, nonceSize)
  283. initNonce(nonce)
  284. httpWriteUpgradeRequest(bw, u, nonce, d.Protocols, d.Extensions, d.Header, d.Host)
  285. if err := bw.Flush(); err != nil {
  286. return br, hs, err
  287. }
  288. // Read HTTP status line like "HTTP/1.1 101 Switching Protocols".
  289. sl, err := readLine(br)
  290. if err != nil {
  291. return br, hs, err
  292. }
  293. // Begin validation of the response.
  294. // See https://tools.ietf.org/html/rfc6455#section-4.2.2
  295. // Parse request line data like HTTP version, uri and method.
  296. resp, err := httpParseResponseLine(sl)
  297. if err != nil {
  298. return br, hs, err
  299. }
  300. // Even if RFC says "1.1 or higher" without mentioning the part of the
  301. // version, we apply it only to minor part.
  302. if resp.major != 1 || resp.minor < 1 {
  303. err = ErrHandshakeBadProtocol
  304. return br, hs, err
  305. }
  306. if resp.status != http.StatusSwitchingProtocols {
  307. err = StatusError(resp.status)
  308. if onStatusError := d.OnStatusError; onStatusError != nil {
  309. // Invoke callback with multireader of status-line bytes br.
  310. onStatusError(resp.status, resp.reason,
  311. io.MultiReader(
  312. bytes.NewReader(sl),
  313. strings.NewReader(crlf),
  314. br,
  315. ),
  316. )
  317. }
  318. return br, hs, err
  319. }
  320. // If response status is 101 then we expect all technical headers to be
  321. // valid. If not, then we stop processing response without giving user
  322. // ability to read non-technical headers. That is, we do not distinguish
  323. // technical errors (such as parsing error) and protocol errors.
  324. var headerSeen byte
  325. for {
  326. line, e := readLine(br)
  327. if e != nil {
  328. err = e
  329. return br, hs, err
  330. }
  331. if len(line) == 0 {
  332. // Blank line, no more lines to read.
  333. break
  334. }
  335. k, v, ok := httpParseHeaderLine(line)
  336. if !ok {
  337. err = ErrMalformedResponse
  338. return br, hs, err
  339. }
  340. switch btsToString(k) {
  341. case headerUpgradeCanonical:
  342. headerSeen |= headerSeenUpgrade
  343. if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
  344. err = ErrHandshakeBadUpgrade
  345. return br, hs, err
  346. }
  347. case headerConnectionCanonical:
  348. headerSeen |= headerSeenConnection
  349. // Note that as RFC6455 says:
  350. // > A |Connection| header field with value "Upgrade".
  351. // That is, in server side, "Connection" header could contain
  352. // multiple token. But in response it must contains exactly one.
  353. if !bytes.Equal(v, specHeaderValueConnection) && !bytes.EqualFold(v, specHeaderValueConnection) {
  354. err = ErrHandshakeBadConnection
  355. return br, hs, err
  356. }
  357. case headerSecAcceptCanonical:
  358. headerSeen |= headerSeenSecAccept
  359. if !checkAcceptFromNonce(v, nonce) {
  360. err = ErrHandshakeBadSecAccept
  361. return br, hs, err
  362. }
  363. case headerSecProtocolCanonical:
  364. // RFC6455 1.3:
  365. // "The server selects one or none of the acceptable protocols
  366. // and echoes that value in its handshake to indicate that it has
  367. // selected that protocol."
  368. for _, want := range d.Protocols {
  369. if string(v) == want {
  370. hs.Protocol = want
  371. break
  372. }
  373. }
  374. if hs.Protocol == "" {
  375. // Server echoed subprotocol that is not present in client
  376. // requested protocols.
  377. err = ErrHandshakeBadSubProtocol
  378. return br, hs, err
  379. }
  380. case headerSecExtensionsCanonical:
  381. hs.Extensions, err = matchSelectedExtensions(v, d.Extensions, hs.Extensions)
  382. if err != nil {
  383. return br, hs, err
  384. }
  385. default:
  386. if onHeader := d.OnHeader; onHeader != nil {
  387. if e := onHeader(k, v); e != nil {
  388. err = e
  389. return br, hs, err
  390. }
  391. }
  392. }
  393. }
  394. if err == nil && headerSeen != headerSeenAll {
  395. switch {
  396. case headerSeen&headerSeenUpgrade == 0:
  397. err = ErrHandshakeBadUpgrade
  398. case headerSeen&headerSeenConnection == 0:
  399. err = ErrHandshakeBadConnection
  400. case headerSeen&headerSeenSecAccept == 0:
  401. err = ErrHandshakeBadSecAccept
  402. default:
  403. panic("unknown headers state")
  404. }
  405. }
  406. return br, hs, err
  407. }
  408. // PutReader returns bufio.Reader instance to the inner reuse pool.
  409. // It is useful in rare cases, when Dialer.Dial() returns non-nil buffer which
  410. // contains unprocessed buffered data, that was sent by the server quickly
  411. // right after handshake.
  412. func PutReader(br *bufio.Reader) {
  413. pbufio.PutReader(br)
  414. }
  415. // StatusError contains an unexpected status-line code from the server.
  416. type StatusError int
  417. func (s StatusError) Error() string {
  418. return "unexpected HTTP response status: " + strconv.Itoa(int(s))
  419. }
  420. func isTimeoutError(err error) bool {
  421. t, ok := err.(net.Error)
  422. return ok && t.Timeout()
  423. }
  424. func matchSelectedExtensions(selected []byte, wanted, received []httphead.Option) ([]httphead.Option, error) {
  425. if len(selected) == 0 {
  426. return received, nil
  427. }
  428. var (
  429. index int
  430. option httphead.Option
  431. err error
  432. )
  433. index = -1
  434. match := func() (ok bool) {
  435. for _, want := range wanted {
  436. // A server accepts one or more extensions by including a
  437. // |Sec-WebSocket-Extensions| header field containing one or more
  438. // extensions that were requested by the client.
  439. //
  440. // The interpretation of any extension parameters, and what
  441. // constitutes a valid response by a server to a requested set of
  442. // parameters by a client, will be defined by each such extension.
  443. if bytes.Equal(option.Name, want.Name) {
  444. // Check parsed extension to be present in client
  445. // requested extensions. We move matched extension
  446. // from client list to avoid allocation of httphead.Option.Name,
  447. // httphead.Option.Parameters have to be copied from the header
  448. want.Parameters, _ = option.Parameters.Copy(make([]byte, option.Parameters.Size()))
  449. received = append(received, want)
  450. return true
  451. }
  452. }
  453. return false
  454. }
  455. ok := httphead.ScanOptions(selected, func(i int, name, attr, val []byte) httphead.Control {
  456. if i != index {
  457. // Met next option.
  458. index = i
  459. if i != 0 && !match() {
  460. // Server returned non-requested extension.
  461. err = ErrHandshakeBadExtensions
  462. return httphead.ControlBreak
  463. }
  464. option = httphead.Option{Name: name}
  465. }
  466. if attr != nil {
  467. option.Parameters.Set(attr, val)
  468. }
  469. return httphead.ControlContinue
  470. })
  471. if !ok {
  472. err = ErrMalformedResponse
  473. return received, err
  474. }
  475. if !match() {
  476. return received, ErrHandshakeBadExtensions
  477. }
  478. return received, err
  479. }
  480. // setupContextDeadliner is a helper function that starts connection I/O
  481. // interrupter goroutine.
  482. //
  483. // Started goroutine calls SetDeadline() with long time ago value when context
  484. // become expired to make any I/O operations failed. It returns done function
  485. // that stops started goroutine and maps error received from conn I/O methods
  486. // to possible context expiration error.
  487. //
  488. // In concern with possible SetDeadline() call inside interrupter goroutine,
  489. // caller passes pointer to its I/O error (even if it is nil) to done(&err).
  490. // That is, even if I/O error is nil, context could be already expired and
  491. // connection "poisoned" by SetDeadline() call. In that case done(&err) will
  492. // store at *err ctx.Err() result. If err is caused not by timeout, it will
  493. // leaved untouched.
  494. func setupContextDeadliner(ctx context.Context, conn net.Conn) (done func(*error)) {
  495. var (
  496. quit = make(chan struct{})
  497. interrupt = make(chan error, 1)
  498. )
  499. go func() {
  500. select {
  501. case <-quit:
  502. interrupt <- nil
  503. case <-ctx.Done():
  504. // Cancel i/o immediately.
  505. conn.SetDeadline(aLongTimeAgo)
  506. interrupt <- ctx.Err()
  507. }
  508. }()
  509. return func(err *error) {
  510. close(quit)
  511. // If ctx.Err() is non-nil and the original err is net.Error with
  512. // Timeout() == true, then it means that I/O was canceled by us by
  513. // SetDeadline(aLongTimeAgo) call, or by somebody else previously
  514. // by conn.SetDeadline(x).
  515. //
  516. // Even on race condition when both deadlines are expired
  517. // (SetDeadline() made not by us and context's), we prefer ctx.Err() to
  518. // be returned.
  519. if ctxErr := <-interrupt; ctxErr != nil && (*err == nil || isTimeoutError(*err)) {
  520. *err = ctxErr
  521. }
  522. }
  523. }