recover from panics such as out-of-bounds array access & nil pointer access, print a stack trace and return 5xx error instead of the service crashing and relying on Execution framework to handle crashes
303 lines
8.6 KiB
Go
303 lines
8.6 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"math/rand"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
)
|
|
|
|
// TODO: Handle TLS proxy
|
|
|
|
type (
|
|
// ProxyConfig defines the config for Proxy middleware.
|
|
ProxyConfig struct {
|
|
// Skipper defines a function to skip middleware.
|
|
Skipper Skipper
|
|
|
|
// Balancer defines a load balancing technique.
|
|
// Required.
|
|
Balancer ProxyBalancer
|
|
|
|
// Rewrite defines URL path rewrite rules. The values captured in asterisk can be
|
|
// retrieved by index e.g. $1, $2 and so on.
|
|
// Examples:
|
|
// "/old": "/new",
|
|
// "/api/*": "/$1",
|
|
// "/js/*": "/public/javascripts/$1",
|
|
// "/users/*/orders/*": "/user/$1/order/$2",
|
|
Rewrite map[string]string
|
|
|
|
// RegexRewrite defines rewrite rules using regexp.Rexexp with captures
|
|
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
|
|
// Example:
|
|
// "^/old/[0.9]+/": "/new",
|
|
// "^/api/.+?/(.*)": "/v2/$1",
|
|
RegexRewrite map[*regexp.Regexp]string
|
|
|
|
// Context key to store selected ProxyTarget into context.
|
|
// Optional. Default value "target".
|
|
ContextKey string
|
|
|
|
// To customize the transport to remote.
|
|
// Examples: If custom TLS certificates are required.
|
|
Transport http.RoundTripper
|
|
|
|
// ModifyResponse defines function to modify response from ProxyTarget.
|
|
ModifyResponse func(*http.Response) error
|
|
}
|
|
|
|
// ProxyTarget defines the upstream target.
|
|
ProxyTarget struct {
|
|
Name string
|
|
URL *url.URL
|
|
Meta echo.Map
|
|
}
|
|
|
|
// ProxyBalancer defines an interface to implement a load balancing technique.
|
|
ProxyBalancer interface {
|
|
AddTarget(*ProxyTarget) bool
|
|
RemoveTarget(string) bool
|
|
Next(echo.Context) *ProxyTarget
|
|
}
|
|
|
|
commonBalancer struct {
|
|
targets []*ProxyTarget
|
|
mutex sync.RWMutex
|
|
}
|
|
|
|
// RandomBalancer implements a random load balancing technique.
|
|
randomBalancer struct {
|
|
*commonBalancer
|
|
random *rand.Rand
|
|
}
|
|
|
|
// RoundRobinBalancer implements a round-robin load balancing technique.
|
|
roundRobinBalancer struct {
|
|
*commonBalancer
|
|
i uint32
|
|
}
|
|
)
|
|
|
|
var (
|
|
// DefaultProxyConfig is the default Proxy middleware config.
|
|
DefaultProxyConfig = ProxyConfig{
|
|
Skipper: DefaultSkipper,
|
|
ContextKey: "target",
|
|
}
|
|
)
|
|
|
|
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
in, _, err := c.Response().Hijack()
|
|
if err != nil {
|
|
c.Set("_error", fmt.Sprintf("proxy raw, hijack error=%v, url=%s", t.URL, err))
|
|
return
|
|
}
|
|
defer in.Close()
|
|
|
|
out, err := net.Dial("tcp", t.URL.Host)
|
|
if err != nil {
|
|
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err)))
|
|
return
|
|
}
|
|
defer out.Close()
|
|
|
|
// Write header
|
|
err = r.Write(out)
|
|
if err != nil {
|
|
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err)))
|
|
return
|
|
}
|
|
|
|
errCh := make(chan error, 2)
|
|
cp := func(dst io.Writer, src io.Reader) {
|
|
_, err = io.Copy(dst, src)
|
|
errCh <- err
|
|
}
|
|
|
|
go cp(out, in)
|
|
go cp(in, out)
|
|
err = <-errCh
|
|
if err != nil && err != io.EOF {
|
|
c.Set("_error", fmt.Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err))
|
|
}
|
|
})
|
|
}
|
|
|
|
// NewRandomBalancer returns a random proxy balancer.
|
|
func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer {
|
|
b := &randomBalancer{commonBalancer: new(commonBalancer)}
|
|
b.targets = targets
|
|
return b
|
|
}
|
|
|
|
// NewRoundRobinBalancer returns a round-robin proxy balancer.
|
|
func NewRoundRobinBalancer(targets []*ProxyTarget) ProxyBalancer {
|
|
b := &roundRobinBalancer{commonBalancer: new(commonBalancer)}
|
|
b.targets = targets
|
|
return b
|
|
}
|
|
|
|
// AddTarget adds an upstream target to the list.
|
|
func (b *commonBalancer) AddTarget(target *ProxyTarget) bool {
|
|
for _, t := range b.targets {
|
|
if t.Name == target.Name {
|
|
return false
|
|
}
|
|
}
|
|
b.mutex.Lock()
|
|
defer b.mutex.Unlock()
|
|
b.targets = append(b.targets, target)
|
|
return true
|
|
}
|
|
|
|
// RemoveTarget removes an upstream target from the list.
|
|
func (b *commonBalancer) RemoveTarget(name string) bool {
|
|
b.mutex.Lock()
|
|
defer b.mutex.Unlock()
|
|
for i, t := range b.targets {
|
|
if t.Name == name {
|
|
b.targets = append(b.targets[:i], b.targets[i+1:]...)
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Next randomly returns an upstream target.
|
|
func (b *randomBalancer) Next(c echo.Context) *ProxyTarget {
|
|
if b.random == nil {
|
|
b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
|
|
}
|
|
b.mutex.RLock()
|
|
defer b.mutex.RUnlock()
|
|
return b.targets[b.random.Intn(len(b.targets))]
|
|
}
|
|
|
|
// Next returns an upstream target using round-robin technique.
|
|
func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget {
|
|
b.i = b.i % uint32(len(b.targets))
|
|
t := b.targets[b.i]
|
|
atomic.AddUint32(&b.i, 1)
|
|
return t
|
|
}
|
|
|
|
// Proxy returns a Proxy middleware.
|
|
//
|
|
// Proxy middleware forwards the request to upstream server using a configured load balancing technique.
|
|
func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
|
|
c := DefaultProxyConfig
|
|
c.Balancer = balancer
|
|
return ProxyWithConfig(c)
|
|
}
|
|
|
|
// ProxyWithConfig returns a Proxy middleware with config.
|
|
// See: `Proxy()`
|
|
func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
|
// Defaults
|
|
if config.Skipper == nil {
|
|
config.Skipper = DefaultProxyConfig.Skipper
|
|
}
|
|
if config.Balancer == nil {
|
|
panic("echo: proxy middleware requires balancer")
|
|
}
|
|
|
|
if config.Rewrite != nil {
|
|
if config.RegexRewrite == nil {
|
|
config.RegexRewrite = make(map[*regexp.Regexp]string)
|
|
}
|
|
for k, v := range rewriteRulesRegex(config.Rewrite) {
|
|
config.RegexRewrite[k] = v
|
|
}
|
|
}
|
|
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) (err error) {
|
|
if config.Skipper(c) {
|
|
return next(c)
|
|
}
|
|
|
|
req := c.Request()
|
|
res := c.Response()
|
|
tgt := config.Balancer.Next(c)
|
|
c.Set(config.ContextKey, tgt)
|
|
|
|
if err := rewriteURL(config.RegexRewrite, req); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Fix header
|
|
// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
|
|
// However, for backward compatibility, legacy behavior is preserved unless you configure Echo#IPExtractor.
|
|
if req.Header.Get(echo.HeaderXRealIP) == "" || c.Echo().IPExtractor != nil {
|
|
req.Header.Set(echo.HeaderXRealIP, c.RealIP())
|
|
}
|
|
if req.Header.Get(echo.HeaderXForwardedProto) == "" {
|
|
req.Header.Set(echo.HeaderXForwardedProto, c.Scheme())
|
|
}
|
|
if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy.
|
|
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
|
|
}
|
|
|
|
// Proxy
|
|
switch {
|
|
case c.IsWebSocket():
|
|
proxyRaw(tgt, c).ServeHTTP(res, req)
|
|
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
|
|
default:
|
|
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
|
|
}
|
|
if e, ok := c.Get("_error").(error); ok {
|
|
err = e
|
|
}
|
|
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// StatusCodeContextCanceled is a custom HTTP status code for situations
|
|
// where a client unexpectedly closed the connection to the server.
|
|
// As there is no standard error code for "client closed connection", but
|
|
// various well-known HTTP clients and server implement this HTTP code we use
|
|
// 499 too instead of the more problematic 5xx, which does not allow to detect this situation
|
|
const StatusCodeContextCanceled = 499
|
|
|
|
func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
|
|
proxy := httputil.NewSingleHostReverseProxy(tgt.URL)
|
|
proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) {
|
|
desc := tgt.URL.String()
|
|
if tgt.Name != "" {
|
|
desc = fmt.Sprintf("%s(%s)", tgt.Name, tgt.URL.String())
|
|
}
|
|
// If the client canceled the request (usually by closing the connection), we can report a
|
|
// client error (4xx) instead of a server error (5xx) to correctly identify the situation.
|
|
// The Go standard library (at of late 2020) wraps the exported, standard
|
|
// context.Canceled error with unexported garbage value requiring a substring check, see
|
|
// https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430
|
|
if err == context.Canceled || strings.Contains(err.Error(), "operation was canceled") {
|
|
httpError := echo.NewHTTPError(StatusCodeContextCanceled, fmt.Sprintf("client closed connection: %v", err))
|
|
httpError.Internal = err
|
|
c.Set("_error", httpError)
|
|
} else {
|
|
httpError := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err))
|
|
httpError.Internal = err
|
|
c.Set("_error", httpError)
|
|
}
|
|
}
|
|
proxy.Transport = config.Transport
|
|
proxy.ModifyResponse = config.ModifyResponse
|
|
return proxy
|
|
}
|