cloudapi: use Recover middleware to handle panics

recover from panics such as out-of-bounds array access & nil
pointer access, print a stack trace and return 5xx error
instead of the service crashing and relying on Execution
framework to handle crashes
This commit is contained in:
Diaa Sami 2021-09-23 18:18:54 +02:00 committed by Tom Gundersen
parent bc9e340ca5
commit 60e403e53e
34 changed files with 4304 additions and 0 deletions

View file

@ -0,0 +1,106 @@
package middleware
import (
"encoding/base64"
"strconv"
"strings"
"github.com/labstack/echo/v4"
)
type (
// BasicAuthConfig defines the config for BasicAuth middleware.
BasicAuthConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Validator is a function to validate BasicAuth credentials.
// Required.
Validator BasicAuthValidator
// Realm is a string to define realm attribute of BasicAuth.
// Default value "Restricted".
Realm string
}
// BasicAuthValidator defines a function to validate BasicAuth credentials.
BasicAuthValidator func(string, string, echo.Context) (bool, error)
)
const (
basic = "basic"
defaultRealm = "Restricted"
)
var (
// DefaultBasicAuthConfig is the default BasicAuth middleware config.
DefaultBasicAuthConfig = BasicAuthConfig{
Skipper: DefaultSkipper,
Realm: defaultRealm,
}
)
// BasicAuth returns an BasicAuth middleware.
//
// For valid credentials it calls the next handler.
// For missing or invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
c := DefaultBasicAuthConfig
c.Validator = fn
return BasicAuthWithConfig(c)
}
// BasicAuthWithConfig returns an BasicAuth middleware with config.
// See `BasicAuth()`.
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
// Defaults
if config.Validator == nil {
panic("echo: basic-auth middleware requires a validator function")
}
if config.Skipper == nil {
config.Skipper = DefaultBasicAuthConfig.Skipper
}
if config.Realm == "" {
config.Realm = defaultRealm
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
auth := c.Request().Header.Get(echo.HeaderAuthorization)
l := len(basic)
if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) {
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err != nil {
return err
}
cred := string(b)
for i := 0; i < len(cred); i++ {
if cred[i] == ':' {
// Verify credentials
valid, err := config.Validator(cred[:i], cred[i+1:], c)
if err != nil {
return err
} else if valid {
return next(c)
}
break
}
}
}
realm := defaultRealm
if config.Realm != defaultRealm {
realm = strconv.Quote(config.Realm)
}
// Need to return `401` for browsers to pop-up login box.
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm)
return echo.ErrUnauthorized
}
}
}

View file

@ -0,0 +1,107 @@
package middleware
import (
"bufio"
"bytes"
"io"
"io/ioutil"
"net"
"net/http"
"github.com/labstack/echo/v4"
)
type (
// BodyDumpConfig defines the config for BodyDump middleware.
BodyDumpConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Handler receives request and response payload.
// Required.
Handler BodyDumpHandler
}
// BodyDumpHandler receives the request and response payload.
BodyDumpHandler func(echo.Context, []byte, []byte)
bodyDumpResponseWriter struct {
io.Writer
http.ResponseWriter
}
)
var (
// DefaultBodyDumpConfig is the default BodyDump middleware config.
DefaultBodyDumpConfig = BodyDumpConfig{
Skipper: DefaultSkipper,
}
)
// BodyDump returns a BodyDump middleware.
//
// BodyDump middleware captures the request and response payload and calls the
// registered handler.
func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc {
c := DefaultBodyDumpConfig
c.Handler = handler
return BodyDumpWithConfig(c)
}
// BodyDumpWithConfig returns a BodyDump middleware with config.
// See: `BodyDump()`.
func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
// Defaults
if config.Handler == nil {
panic("echo: body-dump middleware requires a handler function")
}
if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
if config.Skipper(c) {
return next(c)
}
// Request
reqBody := []byte{}
if c.Request().Body != nil { // Read
reqBody, _ = ioutil.ReadAll(c.Request().Body)
}
c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset
// Response
resBody := new(bytes.Buffer)
mw := io.MultiWriter(c.Response().Writer, resBody)
writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer}
c.Response().Writer = writer
if err = next(c); err != nil {
c.Error(err)
}
// Callback
config.Handler(c, reqBody, resBody.Bytes())
return
}
}
}
func (w *bodyDumpResponseWriter) WriteHeader(code int) {
w.ResponseWriter.WriteHeader(code)
}
func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}
func (w *bodyDumpResponseWriter) Flush() {
w.ResponseWriter.(http.Flusher).Flush()
}
func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack()
}

View file

@ -0,0 +1,117 @@
package middleware
import (
"fmt"
"io"
"sync"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/bytes"
)
type (
// BodyLimitConfig defines the config for BodyLimit middleware.
BodyLimitConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Maximum allowed size for a request body, it can be specified
// as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P.
Limit string `yaml:"limit"`
limit int64
}
limitedReader struct {
BodyLimitConfig
reader io.ReadCloser
read int64
context echo.Context
}
)
var (
// DefaultBodyLimitConfig is the default BodyLimit middleware config.
DefaultBodyLimitConfig = BodyLimitConfig{
Skipper: DefaultSkipper,
}
)
// BodyLimit returns a BodyLimit middleware.
//
// BodyLimit middleware sets the maximum allowed size for a request body, if the
// size exceeds the configured limit, it sends "413 - Request Entity Too Large"
// response. The BodyLimit is determined based on both `Content-Length` request
// header and actual content read, which makes it super secure.
// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M,
// G, T or P.
func BodyLimit(limit string) echo.MiddlewareFunc {
c := DefaultBodyLimitConfig
c.Limit = limit
return BodyLimitWithConfig(c)
}
// BodyLimitWithConfig returns a BodyLimit middleware with config.
// See: `BodyLimit()`.
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultBodyLimitConfig.Skipper
}
limit, err := bytes.Parse(config.Limit)
if err != nil {
panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit))
}
config.limit = limit
pool := limitedReaderPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
// Based on content length
if req.ContentLength > config.limit {
return echo.ErrStatusRequestEntityTooLarge
}
// Based on content read
r := pool.Get().(*limitedReader)
r.Reset(req.Body, c)
defer pool.Put(r)
req.Body = r
return next(c)
}
}
}
func (r *limitedReader) Read(b []byte) (n int, err error) {
n, err = r.reader.Read(b)
r.read += int64(n)
if r.read > r.limit {
return n, echo.ErrStatusRequestEntityTooLarge
}
return
}
func (r *limitedReader) Close() error {
return r.reader.Close()
}
func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) {
r.reader = reader
r.context = context
r.read = 0
}
func limitedReaderPool(c BodyLimitConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
return &limitedReader{BodyLimitConfig: c}
},
}
}

View file

@ -0,0 +1,146 @@
package middleware
import (
"bufio"
"compress/gzip"
"io"
"io/ioutil"
"net"
"net/http"
"strings"
"sync"
"github.com/labstack/echo/v4"
)
type (
// GzipConfig defines the config for Gzip middleware.
GzipConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Gzip compression level.
// Optional. Default value -1.
Level int `yaml:"level"`
}
gzipResponseWriter struct {
io.Writer
http.ResponseWriter
}
)
const (
gzipScheme = "gzip"
)
var (
// DefaultGzipConfig is the default Gzip middleware config.
DefaultGzipConfig = GzipConfig{
Skipper: DefaultSkipper,
Level: -1,
}
)
// Gzip returns a middleware which compresses HTTP response using gzip compression
// scheme.
func Gzip() echo.MiddlewareFunc {
return GzipWithConfig(DefaultGzipConfig)
}
// GzipWithConfig return Gzip middleware with config.
// See: `Gzip()`.
func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper
}
if config.Level == 0 {
config.Level = DefaultGzipConfig.Level
}
pool := gzipCompressPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
res := c.Response()
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
i := pool.Get()
w, ok := i.(*gzip.Writer)
if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
}
rw := res.Writer
w.Reset(rw)
defer func() {
if res.Size == 0 {
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
res.Header().Del(echo.HeaderContentEncoding)
}
// We have to reset response to it's pristine state when
// nothing is written to body or error is returned.
// See issue #424, #407.
res.Writer = rw
w.Reset(ioutil.Discard)
}
w.Close()
pool.Put(w)
}()
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
res.Writer = grw
}
return next(c)
}
}
}
func (w *gzipResponseWriter) WriteHeader(code int) {
if code == http.StatusNoContent { // Issue #489
w.ResponseWriter.Header().Del(echo.HeaderContentEncoding)
}
w.Header().Del(echo.HeaderContentLength) // Issue #444
w.ResponseWriter.WriteHeader(code)
}
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
if w.Header().Get(echo.HeaderContentType) == "" {
w.Header().Set(echo.HeaderContentType, http.DetectContentType(b))
}
return w.Writer.Write(b)
}
func (w *gzipResponseWriter) Flush() {
w.Writer.(*gzip.Writer).Flush()
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack()
}
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
if p, ok := w.ResponseWriter.(http.Pusher); ok {
return p.Push(target, opts)
}
return http.ErrNotSupported
}
func gzipCompressPool(config GzipConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level)
if err != nil {
return err
}
return w
},
}
}

211
vendor/github.com/labstack/echo/v4/middleware/cors.go generated vendored Normal file
View file

@ -0,0 +1,211 @@
package middleware
import (
"net/http"
"regexp"
"strconv"
"strings"
"github.com/labstack/echo/v4"
)
type (
// CORSConfig defines the config for CORS middleware.
CORSConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// AllowOrigin defines a list of origins that may access the resource.
// Optional. Default value []string{"*"}.
AllowOrigins []string `yaml:"allow_origins"`
// AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored.
// Optional.
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
// AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods.
AllowMethods []string `yaml:"allow_methods"`
// AllowHeaders defines a list of request headers that can be used when
// making the actual request. This is in response to a preflight request.
// Optional. Default value []string{}.
AllowHeaders []string `yaml:"allow_headers"`
// AllowCredentials indicates whether or not the response to the request
// can be exposed when the credentials flag is true. When used as part of
// a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials.
// Optional. Default value false.
AllowCredentials bool `yaml:"allow_credentials"`
// ExposeHeaders defines a whitelist headers that clients are allowed to
// access.
// Optional. Default value []string{}.
ExposeHeaders []string `yaml:"expose_headers"`
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached.
// Optional. Default value 0.
MaxAge int `yaml:"max_age"`
}
)
var (
// DefaultCORSConfig is the default CORS middleware config.
DefaultCORSConfig = CORSConfig{
Skipper: DefaultSkipper,
AllowOrigins: []string{"*"},
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
}
)
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
func CORS() echo.MiddlewareFunc {
return CORSWithConfig(DefaultCORSConfig)
}
// CORSWithConfig returns a CORS middleware with config.
// See: `CORS()`.
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultCORSConfig.Skipper
}
if len(config.AllowOrigins) == 0 {
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
}
if len(config.AllowMethods) == 0 {
config.AllowMethods = DefaultCORSConfig.AllowMethods
}
allowOriginPatterns := []string{}
for _, origin := range config.AllowOrigins {
pattern := regexp.QuoteMeta(origin)
pattern = strings.Replace(pattern, "\\*", ".*", -1)
pattern = strings.Replace(pattern, "\\?", ".", -1)
pattern = "^" + pattern + "$"
allowOriginPatterns = append(allowOriginPatterns, pattern)
}
allowMethods := strings.Join(config.AllowMethods, ",")
allowHeaders := strings.Join(config.AllowHeaders, ",")
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
maxAge := strconv.Itoa(config.MaxAge)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
res := c.Response()
origin := req.Header.Get(echo.HeaderOrigin)
allowOrigin := ""
preflight := req.Method == http.MethodOptions
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
// No Origin provided
if origin == "" {
if !preflight {
return next(c)
}
return c.NoContent(http.StatusNoContent)
}
if config.AllowOriginFunc != nil {
allowed, err := config.AllowOriginFunc(origin)
if err != nil {
return err
}
if allowed {
allowOrigin = origin
}
} else {
// Check allowed origins
for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials {
allowOrigin = origin
break
}
if o == "*" || o == origin {
allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin
break
}
}
// Check allowed origin patterns
for _, re := range allowOriginPatterns {
if allowOrigin == "" {
didx := strings.Index(origin, "://")
if didx == -1 {
continue
}
domAuth := origin[didx+3:]
// to avoid regex cost by invalid long domain
if len(domAuth) > 253 {
break
}
if match, _ := regexp.MatchString(re, origin); match {
allowOrigin = origin
break
}
}
}
}
// Origin not allowed
if allowOrigin == "" {
if !preflight {
return next(c)
}
return c.NoContent(http.StatusNoContent)
}
// Simple request
if !preflight {
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
}
if exposeHeaders != "" {
res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
}
return next(c)
}
// Preflight request
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
}
if allowHeaders != "" {
res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders)
} else {
h := req.Header.Get(echo.HeaderAccessControlRequestHeaders)
if h != "" {
res.Header().Set(echo.HeaderAccessControlAllowHeaders, h)
}
}
if config.MaxAge > 0 {
res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge)
}
return c.NoContent(http.StatusNoContent)
}
}
}

221
vendor/github.com/labstack/echo/v4/middleware/csrf.go generated vendored Normal file
View file

@ -0,0 +1,221 @@
package middleware
import (
"crypto/subtle"
"errors"
"net/http"
"strings"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
)
type (
// CSRFConfig defines the config for CSRF middleware.
CSRFConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// TokenLength is the length of the generated token.
TokenLength uint8 `yaml:"token_length"`
// Optional. Default value 32.
// TokenLookup is a string in the form of "<source>:<key>" that is used
// to extract token from the request.
// Optional. Default value "header:X-CSRF-Token".
// Possible values:
// - "header:<name>"
// - "form:<name>"
// - "query:<name>"
TokenLookup string `yaml:"token_lookup"`
// Context key to store generated CSRF token into context.
// Optional. Default value "csrf".
ContextKey string `yaml:"context_key"`
// Name of the CSRF cookie. This cookie will store CSRF token.
// Optional. Default value "csrf".
CookieName string `yaml:"cookie_name"`
// Domain of the CSRF cookie.
// Optional. Default value none.
CookieDomain string `yaml:"cookie_domain"`
// Path of the CSRF cookie.
// Optional. Default value none.
CookiePath string `yaml:"cookie_path"`
// Max age (in seconds) of the CSRF cookie.
// Optional. Default value 86400 (24hr).
CookieMaxAge int `yaml:"cookie_max_age"`
// Indicates if CSRF cookie is secure.
// Optional. Default value false.
CookieSecure bool `yaml:"cookie_secure"`
// Indicates if CSRF cookie is HTTP only.
// Optional. Default value false.
CookieHTTPOnly bool `yaml:"cookie_http_only"`
// Indicates SameSite mode of the CSRF cookie.
// Optional. Default value SameSiteDefaultMode.
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
}
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
// either a token or an error.
csrfTokenExtractor func(echo.Context) (string, error)
)
var (
// DefaultCSRFConfig is the default CSRF middleware config.
DefaultCSRFConfig = CSRFConfig{
Skipper: DefaultSkipper,
TokenLength: 32,
TokenLookup: "header:" + echo.HeaderXCSRFToken,
ContextKey: "csrf",
CookieName: "_csrf",
CookieMaxAge: 86400,
CookieSameSite: http.SameSiteDefaultMode,
}
)
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
func CSRF() echo.MiddlewareFunc {
c := DefaultCSRFConfig
return CSRFWithConfig(c)
}
// CSRFWithConfig returns a CSRF middleware with config.
// See `CSRF()`.
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultCSRFConfig.Skipper
}
if config.TokenLength == 0 {
config.TokenLength = DefaultCSRFConfig.TokenLength
}
if config.TokenLookup == "" {
config.TokenLookup = DefaultCSRFConfig.TokenLookup
}
if config.ContextKey == "" {
config.ContextKey = DefaultCSRFConfig.ContextKey
}
if config.CookieName == "" {
config.CookieName = DefaultCSRFConfig.CookieName
}
if config.CookieMaxAge == 0 {
config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge
}
if config.CookieSameSite == http.SameSiteNoneMode {
config.CookieSecure = true
}
// Initialize
parts := strings.Split(config.TokenLookup, ":")
extractor := csrfTokenFromHeader(parts[1])
switch parts[0] {
case "form":
extractor = csrfTokenFromForm(parts[1])
case "query":
extractor = csrfTokenFromQuery(parts[1])
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
k, err := c.Cookie(config.CookieName)
token := ""
// Generate token
if err != nil {
token = random.String(config.TokenLength)
} else {
// Reuse token
token = k.Value
}
switch req.Method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
default:
// Validate token only for requests which are not defined as 'safe' by RFC7231
clientToken, err := extractor(c)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if !validateCSRFToken(token, clientToken) {
return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
}
}
// Set CSRF cookie
cookie := new(http.Cookie)
cookie.Name = config.CookieName
cookie.Value = token
if config.CookiePath != "" {
cookie.Path = config.CookiePath
}
if config.CookieDomain != "" {
cookie.Domain = config.CookieDomain
}
if config.CookieSameSite != http.SameSiteDefaultMode {
cookie.SameSite = config.CookieSameSite
}
cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
cookie.Secure = config.CookieSecure
cookie.HttpOnly = config.CookieHTTPOnly
c.SetCookie(cookie)
// Store token in the context
c.Set(config.ContextKey, token)
// Protect clients from caching the response
c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)
return next(c)
}
}
}
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
// provided request header.
func csrfTokenFromHeader(header string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
return c.Request().Header.Get(header), nil
}
}
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
// provided form parameter.
func csrfTokenFromForm(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.FormValue(param)
if token == "" {
return "", errors.New("missing csrf token in the form parameter")
}
return token, nil
}
}
// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
// provided query parameter.
func csrfTokenFromQuery(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {
return "", errors.New("missing csrf token in the query string")
}
return token, nil
}
}
func validateCSRFToken(token, clientToken string) bool {
return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
}

View file

@ -0,0 +1,120 @@
package middleware
import (
"bytes"
"compress/gzip"
"io"
"io/ioutil"
"net/http"
"sync"
"github.com/labstack/echo/v4"
)
type (
// DecompressConfig defines the config for Decompress middleware.
DecompressConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
GzipDecompressPool Decompressor
}
)
//GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
const GZIPEncoding string = "gzip"
// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers
type Decompressor interface {
gzipDecompressPool() sync.Pool
}
var (
//DefaultDecompressConfig defines the config for decompress middleware
DefaultDecompressConfig = DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &DefaultGzipDecompressPool{},
}
)
// DefaultGzipDecompressPool is the default implementation of Decompressor interface
type DefaultGzipDecompressPool struct {
}
func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
return sync.Pool{
New: func() interface{} {
// create with an empty reader (but with GZIP header)
w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed)
if err != nil {
return err
}
b := new(bytes.Buffer)
w.Reset(b)
w.Flush()
w.Close()
r, err := gzip.NewReader(bytes.NewReader(b.Bytes()))
if err != nil {
return err
}
return r
},
}
}
//Decompress decompresses request body based if content encoding type is set to "gzip" with default config
func Decompress() echo.MiddlewareFunc {
return DecompressWithConfig(DefaultDecompressConfig)
}
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper
}
if config.GzipDecompressPool == nil {
config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
pool := config.GzipDecompressPool.gzipDecompressPool()
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
switch c.Request().Header.Get(echo.HeaderContentEncoding) {
case GZIPEncoding:
b := c.Request().Body
i := pool.Get()
gr, ok := i.(*gzip.Reader)
if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
}
if err := gr.Reset(b); err != nil {
pool.Put(gr)
if err == io.EOF { //ignore if body is empty
return next(c)
}
return err
}
var buf bytes.Buffer
io.Copy(&buf, gr)
gr.Close()
pool.Put(gr)
b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here
r := ioutil.NopCloser(&buf)
c.Request().Body = r
}
return next(c)
}
}
}

347
vendor/github.com/labstack/echo/v4/middleware/jwt.go generated vendored Normal file
View file

@ -0,0 +1,347 @@
// +build go1.15
package middleware
import (
"errors"
"fmt"
"net/http"
"reflect"
"strings"
"github.com/golang-jwt/jwt"
"github.com/labstack/echo/v4"
)
type (
// JWTConfig defines the config for JWT middleware.
JWTConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// BeforeFunc defines a function which is executed just before the middleware.
BeforeFunc BeforeFunc
// SuccessHandler defines a function which is executed for a valid token.
SuccessHandler JWTSuccessHandler
// ErrorHandler defines a function which is executed for an invalid token.
// It may be used to define a custom JWT error.
ErrorHandler JWTErrorHandler
// ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context.
ErrorHandlerWithContext JWTErrorHandlerWithContext
// Signing key to validate token.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither user-defined KeyFunc nor SigningKeys is provided.
SigningKey interface{}
// Map of signing keys to validate token with kid field usage.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither user-defined KeyFunc nor SigningKey is provided.
SigningKeys map[string]interface{}
// Signing method used to check the token's signing algorithm.
// Optional. Default value HS256.
SigningMethod string
// Context key to store user information from the token into context.
// Optional. Default value "user".
ContextKey string
// Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation.
// Not used if custom ParseTokenFunc is set.
// Optional. Default value jwt.MapClaims
Claims jwt.Claims
// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract token from the request.
// Optional. Default value "header:Authorization".
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "param:<name>"
// - "cookie:<name>"
// - "form:<name>"
// Multiply sources example:
// - "header: Authorization,cookie: myowncookie"
TokenLookup string
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
// KeyFunc defines a user-defined function that supplies the public key for a token validation.
// The function shall take care of verifying the signing algorithm and selecting the proper key.
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
// Used by default ParseTokenFunc implementation.
//
// When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither SigningKeys nor SigningKey is provided.
// Not used if custom ParseTokenFunc is set.
// Default to an internal implementation verifying the signing algorithm and selecting the proper key.
KeyFunc jwt.Keyfunc
// ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
// parsing fails or parsed token is invalid.
// Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library
ParseTokenFunc func(auth string, c echo.Context) (interface{}, error)
}
// JWTSuccessHandler defines a function which is executed for a valid token.
JWTSuccessHandler func(echo.Context)
// JWTErrorHandler defines a function which is executed for an invalid token.
JWTErrorHandler func(error) error
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
JWTErrorHandlerWithContext func(error, echo.Context) error
jwtExtractor func(echo.Context) (string, error)
)
// Algorithms
const (
AlgorithmHS256 = "HS256"
)
// Errors
var (
ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt")
ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt")
)
var (
// DefaultJWTConfig is the default JWT auth middleware config.
DefaultJWTConfig = JWTConfig{
Skipper: DefaultSkipper,
SigningMethod: AlgorithmHS256,
ContextKey: "user",
TokenLookup: "header:" + echo.HeaderAuthorization,
AuthScheme: "Bearer",
Claims: jwt.MapClaims{},
KeyFunc: nil,
}
)
// JWT returns a JSON Web Token (JWT) auth middleware.
//
// For valid token, it sets the user in context and calls next handler.
// For invalid token, it returns "401 - Unauthorized" error.
// For missing token, it returns "400 - Bad Request" error.
//
// See: https://jwt.io/introduction
// See `JWTConfig.TokenLookup`
func JWT(key interface{}) echo.MiddlewareFunc {
c := DefaultJWTConfig
c.SigningKey = key
return JWTWithConfig(c)
}
// JWTWithConfig returns a JWT auth middleware with config.
// See: `JWT()`.
func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultJWTConfig.Skipper
}
if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil {
panic("echo: jwt middleware requires signing key")
}
if config.SigningMethod == "" {
config.SigningMethod = DefaultJWTConfig.SigningMethod
}
if config.ContextKey == "" {
config.ContextKey = DefaultJWTConfig.ContextKey
}
if config.Claims == nil {
config.Claims = DefaultJWTConfig.Claims
}
if config.TokenLookup == "" {
config.TokenLookup = DefaultJWTConfig.TokenLookup
}
if config.AuthScheme == "" {
config.AuthScheme = DefaultJWTConfig.AuthScheme
}
if config.KeyFunc == nil {
config.KeyFunc = config.defaultKeyFunc
}
if config.ParseTokenFunc == nil {
config.ParseTokenFunc = config.defaultParseToken
}
// Initialize
// Split sources
sources := strings.Split(config.TokenLookup, ",")
var extractors []jwtExtractor
for _, source := range sources {
parts := strings.Split(source, ":")
switch parts[0] {
case "query":
extractors = append(extractors, jwtFromQuery(parts[1]))
case "param":
extractors = append(extractors, jwtFromParam(parts[1]))
case "cookie":
extractors = append(extractors, jwtFromCookie(parts[1]))
case "form":
extractors = append(extractors, jwtFromForm(parts[1]))
case "header":
extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme))
}
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
if config.BeforeFunc != nil {
config.BeforeFunc(c)
}
var auth string
var err error
for _, extractor := range extractors {
// Extract token from extractor, if it's not fail break the loop and
// set auth
auth, err = extractor(c)
if err == nil {
break
}
}
// If none of extractor has a token, handle error
if err != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(err)
}
if config.ErrorHandlerWithContext != nil {
return config.ErrorHandlerWithContext(err, c)
}
return err
}
token, err := config.ParseTokenFunc(auth, c)
if err == nil {
// Store user information from token into context.
c.Set(config.ContextKey, token)
if config.SuccessHandler != nil {
config.SuccessHandler(c)
}
return next(c)
}
if config.ErrorHandler != nil {
return config.ErrorHandler(err)
}
if config.ErrorHandlerWithContext != nil {
return config.ErrorHandlerWithContext(err, c)
}
return &echo.HTTPError{
Code: ErrJWTInvalid.Code,
Message: ErrJWTInvalid.Message,
Internal: err,
}
}
}
}
func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) {
token := new(jwt.Token)
var err error
// Issue #647, #656
if _, ok := config.Claims.(jwt.MapClaims); ok {
token, err = jwt.Parse(auth, config.KeyFunc)
} else {
t := reflect.ValueOf(config.Claims).Type().Elem()
claims := reflect.New(t).Interface().(jwt.Claims)
token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc)
}
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return token, nil
}
// defaultKeyFunc returns a signing key of the given token.
func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
// Check the signing method
if t.Method.Alg() != config.SigningMethod {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
if len(config.SigningKeys) > 0 {
if kid, ok := t.Header["kid"].(string); ok {
if key, ok := config.SigningKeys[kid]; ok {
return key, nil
}
}
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
}
return config.SigningKey, nil
}
// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header.
func jwtFromHeader(header string, authScheme string) jwtExtractor {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
l := len(authScheme)
if len(auth) > l+1 && auth[:l] == authScheme {
return auth[l+1:], nil
}
return "", ErrJWTMissing
}
}
// jwtFromQuery returns a `jwtExtractor` that extracts token from the query string.
func jwtFromQuery(param string) jwtExtractor {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {
return "", ErrJWTMissing
}
return token, nil
}
}
// jwtFromParam returns a `jwtExtractor` that extracts token from the url param string.
func jwtFromParam(param string) jwtExtractor {
return func(c echo.Context) (string, error) {
token := c.Param(param)
if token == "" {
return "", ErrJWTMissing
}
return token, nil
}
}
// jwtFromCookie returns a `jwtExtractor` that extracts token from the named cookie.
func jwtFromCookie(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
cookie, err := c.Cookie(name)
if err != nil {
return "", ErrJWTMissing
}
return cookie.Value, nil
}
}
// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
func jwtFromForm(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
field := c.FormValue(name)
if field == "" {
return "", ErrJWTMissing
}
return field, nil
}
}

View file

@ -0,0 +1,166 @@
package middleware
import (
"errors"
"net/http"
"strings"
"github.com/labstack/echo/v4"
)
type (
// KeyAuthConfig defines the config for KeyAuth middleware.
KeyAuthConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// KeyLookup is a string in the form of "<source>:<name>" that is used
// to extract key from the request.
// Optional. Default value "header:Authorization".
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "form:<name>"
KeyLookup string `yaml:"key_lookup"`
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
// Validator is a function to validate key.
// Required.
Validator KeyAuthValidator
// ErrorHandler defines a function which is executed for an invalid key.
// It may be used to define a custom error.
ErrorHandler KeyAuthErrorHandler
}
// KeyAuthValidator defines a function to validate KeyAuth credentials.
KeyAuthValidator func(string, echo.Context) (bool, error)
keyExtractor func(echo.Context) (string, error)
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
KeyAuthErrorHandler func(error, echo.Context) error
)
var (
// DefaultKeyAuthConfig is the default KeyAuth middleware config.
DefaultKeyAuthConfig = KeyAuthConfig{
Skipper: DefaultSkipper,
KeyLookup: "header:" + echo.HeaderAuthorization,
AuthScheme: "Bearer",
}
)
// KeyAuth returns an KeyAuth middleware.
//
// For valid key it calls the next handler.
// For invalid key, it sends "401 - Unauthorized" response.
// For missing key, it sends "400 - Bad Request" response.
func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc {
c := DefaultKeyAuthConfig
c.Validator = fn
return KeyAuthWithConfig(c)
}
// KeyAuthWithConfig returns an KeyAuth middleware with config.
// See `KeyAuth()`.
func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultKeyAuthConfig.Skipper
}
// Defaults
if config.AuthScheme == "" {
config.AuthScheme = DefaultKeyAuthConfig.AuthScheme
}
if config.KeyLookup == "" {
config.KeyLookup = DefaultKeyAuthConfig.KeyLookup
}
if config.Validator == nil {
panic("echo: key-auth middleware requires a validator function")
}
// Initialize
parts := strings.Split(config.KeyLookup, ":")
extractor := keyFromHeader(parts[1], config.AuthScheme)
switch parts[0] {
case "query":
extractor = keyFromQuery(parts[1])
case "form":
extractor = keyFromForm(parts[1])
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
// Extract and verify key
key, err := extractor(c)
if err != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(err, c)
}
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
valid, err := config.Validator(key, c)
if err != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(err, c)
}
return &echo.HTTPError{
Code: http.StatusUnauthorized,
Message: "invalid key",
Internal: err,
}
} else if valid {
return next(c)
}
return echo.ErrUnauthorized
}
}
}
// keyFromHeader returns a `keyExtractor` that extracts key from the request header.
func keyFromHeader(header string, authScheme string) keyExtractor {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
if auth == "" {
return "", errors.New("missing key in request header")
}
if header == echo.HeaderAuthorization {
l := len(authScheme)
if len(auth) > l+1 && auth[:l] == authScheme {
return auth[l+1:], nil
}
return "", errors.New("invalid key in the request header")
}
return auth, nil
}
}
// keyFromQuery returns a `keyExtractor` that extracts key from the query string.
func keyFromQuery(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.QueryParam(param)
if key == "" {
return "", errors.New("missing key in the query string")
}
return key, nil
}
}
// keyFromForm returns a `keyExtractor` that extracts key from the form.
func keyFromForm(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.FormValue(param)
if key == "" {
return "", errors.New("missing key in the form")
}
return key, nil
}
}

223
vendor/github.com/labstack/echo/v4/middleware/logger.go generated vendored Normal file
View file

@ -0,0 +1,223 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"strconv"
"strings"
"sync"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/color"
"github.com/valyala/fasttemplate"
)
type (
// LoggerConfig defines the config for Logger middleware.
LoggerConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Tags to construct the logger format.
//
// - time_unix
// - time_unix_nano
// - time_rfc3339
// - time_rfc3339_nano
// - time_custom
// - id (Request ID)
// - remote_ip
// - uri
// - host
// - method
// - path
// - protocol
// - referer
// - user_agent
// - status
// - error
// - latency (In nanoseconds)
// - latency_human (Human readable)
// - bytes_in (Bytes received)
// - bytes_out (Bytes sent)
// - header:<NAME>
// - query:<NAME>
// - form:<NAME>
//
// Example "${remote_ip} ${status}"
//
// Optional. Default value DefaultLoggerConfig.Format.
Format string `yaml:"format"`
// Optional. Default value DefaultLoggerConfig.CustomTimeFormat.
CustomTimeFormat string `yaml:"custom_time_format"`
// Output is a writer where logs in JSON format are written.
// Optional. Default value os.Stdout.
Output io.Writer
template *fasttemplate.Template
colorer *color.Color
pool *sync.Pool
}
)
var (
// DefaultLoggerConfig is the default Logger middleware config.
DefaultLoggerConfig = LoggerConfig{
Skipper: DefaultSkipper,
Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` +
`"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` +
`"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` +
`,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n",
CustomTimeFormat: "2006-01-02 15:04:05.00000",
colorer: color.New(),
}
)
// Logger returns a middleware that logs HTTP requests.
func Logger() echo.MiddlewareFunc {
return LoggerWithConfig(DefaultLoggerConfig)
}
// LoggerWithConfig returns a Logger middleware with config.
// See: `Logger()`.
func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultLoggerConfig.Skipper
}
if config.Format == "" {
config.Format = DefaultLoggerConfig.Format
}
if config.Output == nil {
config.Output = DefaultLoggerConfig.Output
}
config.template = fasttemplate.New(config.Format, "${", "}")
config.colorer = color.New()
config.colorer.SetOutput(config.Output)
config.pool = &sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 256))
},
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
res := c.Response()
start := time.Now()
if err = next(c); err != nil {
c.Error(err)
}
stop := time.Now()
buf := config.pool.Get().(*bytes.Buffer)
buf.Reset()
defer config.pool.Put(buf)
if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) {
switch tag {
case "time_unix":
return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10))
case "time_unix_nano":
return buf.WriteString(strconv.FormatInt(time.Now().UnixNano(), 10))
case "time_rfc3339":
return buf.WriteString(time.Now().Format(time.RFC3339))
case "time_rfc3339_nano":
return buf.WriteString(time.Now().Format(time.RFC3339Nano))
case "time_custom":
return buf.WriteString(time.Now().Format(config.CustomTimeFormat))
case "id":
id := req.Header.Get(echo.HeaderXRequestID)
if id == "" {
id = res.Header().Get(echo.HeaderXRequestID)
}
return buf.WriteString(id)
case "remote_ip":
return buf.WriteString(c.RealIP())
case "host":
return buf.WriteString(req.Host)
case "uri":
return buf.WriteString(req.RequestURI)
case "method":
return buf.WriteString(req.Method)
case "path":
p := req.URL.Path
if p == "" {
p = "/"
}
return buf.WriteString(p)
case "protocol":
return buf.WriteString(req.Proto)
case "referer":
return buf.WriteString(req.Referer())
case "user_agent":
return buf.WriteString(req.UserAgent())
case "status":
n := res.Status
s := config.colorer.Green(n)
switch {
case n >= 500:
s = config.colorer.Red(n)
case n >= 400:
s = config.colorer.Yellow(n)
case n >= 300:
s = config.colorer.Cyan(n)
}
return buf.WriteString(s)
case "error":
if err != nil {
// Error may contain invalid JSON e.g. `"`
b, _ := json.Marshal(err.Error())
b = b[1 : len(b)-1]
return buf.Write(b)
}
case "latency":
l := stop.Sub(start)
return buf.WriteString(strconv.FormatInt(int64(l), 10))
case "latency_human":
return buf.WriteString(stop.Sub(start).String())
case "bytes_in":
cl := req.Header.Get(echo.HeaderContentLength)
if cl == "" {
cl = "0"
}
return buf.WriteString(cl)
case "bytes_out":
return buf.WriteString(strconv.FormatInt(res.Size, 10))
default:
switch {
case strings.HasPrefix(tag, "header:"):
return buf.Write([]byte(c.Request().Header.Get(tag[7:])))
case strings.HasPrefix(tag, "query:"):
return buf.Write([]byte(c.QueryParam(tag[6:])))
case strings.HasPrefix(tag, "form:"):
return buf.Write([]byte(c.FormValue(tag[5:])))
case strings.HasPrefix(tag, "cookie:"):
cookie, err := c.Cookie(tag[7:])
if err == nil {
return buf.Write([]byte(cookie.Value))
}
}
}
return 0, nil
}); err != nil {
return
}
if config.Output == nil {
_, err = c.Logger().Output().Write(buf.Bytes())
return
}
_, err = config.Output.Write(buf.Bytes())
return
}
}
}

View file

@ -0,0 +1,92 @@
package middleware
import (
"net/http"
"github.com/labstack/echo/v4"
)
type (
// MethodOverrideConfig defines the config for MethodOverride middleware.
MethodOverrideConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Getter is a function that gets overridden method from the request.
// Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride).
Getter MethodOverrideGetter
}
// MethodOverrideGetter is a function that gets overridden method from the request
MethodOverrideGetter func(echo.Context) string
)
var (
// DefaultMethodOverrideConfig is the default MethodOverride middleware config.
DefaultMethodOverrideConfig = MethodOverrideConfig{
Skipper: DefaultSkipper,
Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
}
)
// MethodOverride returns a MethodOverride middleware.
// MethodOverride middleware checks for the overridden method from the request and
// uses it instead of the original method.
//
// For security reasons, only `POST` method can be overridden.
func MethodOverride() echo.MiddlewareFunc {
return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
}
// MethodOverrideWithConfig returns a MethodOverride middleware with config.
// See: `MethodOverride()`.
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultMethodOverrideConfig.Skipper
}
if config.Getter == nil {
config.Getter = DefaultMethodOverrideConfig.Getter
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
if req.Method == http.MethodPost {
m := config.Getter(c)
if m != "" {
req.Method = m
}
}
return next(c)
}
}
}
// MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from
// the request header.
func MethodFromHeader(header string) MethodOverrideGetter {
return func(c echo.Context) string {
return c.Request().Header.Get(header)
}
}
// MethodFromForm is a `MethodOverrideGetter` that gets overridden method from the
// form parameter.
func MethodFromForm(param string) MethodOverrideGetter {
return func(c echo.Context) string {
return c.FormValue(param)
}
}
// MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from
// the query parameter.
func MethodFromQuery(param string) MethodOverrideGetter {
return func(c echo.Context) string {
return c.QueryParam(param)
}
}

View file

@ -0,0 +1,89 @@
package middleware
import (
"net/http"
"regexp"
"strconv"
"strings"
"github.com/labstack/echo/v4"
)
type (
// Skipper defines a function to skip middleware. Returning true skips processing
// the middleware.
Skipper func(echo.Context) bool
// BeforeFunc defines a function which is executed just before the middleware.
BeforeFunc func(echo.Context)
)
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
groups := pattern.FindAllStringSubmatch(input, -1)
if groups == nil {
return nil
}
values := groups[0][1:]
replace := make([]string, 2*len(values))
for i, v := range values {
j := 2 * i
replace[j] = "$" + strconv.Itoa(i+1)
replace[j+1] = v
}
return strings.NewReplacer(replace...)
}
func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
// Initialize
rulesRegex := map[*regexp.Regexp]string{}
for k, v := range rewrite {
k = regexp.QuoteMeta(k)
k = strings.Replace(k, `\*`, "(.*?)", -1)
if strings.HasPrefix(k, `\^`) {
k = strings.Replace(k, `\^`, "^", -1)
}
k = k + "$"
rulesRegex[regexp.MustCompile(k)] = v
}
return rulesRegex
}
func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error {
if len(rewriteRegex) == 0 {
return nil
}
// Depending how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path.
// We only want to use path part for rewriting and therefore trim prefix if it exists
rawURI := req.RequestURI
if rawURI != "" && rawURI[0] != '/' {
prefix := ""
if req.URL.Scheme != "" {
prefix = req.URL.Scheme + "://"
}
if req.URL.Host != "" {
prefix += req.URL.Host // host or host:port
}
if prefix != "" {
rawURI = strings.TrimPrefix(rawURI, prefix)
}
}
for k, v := range rewriteRegex {
if replacer := captureTokens(k, rawURI); replacer != nil {
url, err := req.URL.Parse(replacer.Replace(v))
if err != nil {
return err
}
req.URL = url
return nil // rewrite only once
}
}
return nil
}
// DefaultSkipper returns false which processes the middleware.
func DefaultSkipper(echo.Context) bool {
return false
}

303
vendor/github.com/labstack/echo/v4/middleware/proxy.go generated vendored Normal file
View file

@ -0,0 +1,303 @@
package middleware
import (
"context"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"net/url"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/labstack/echo/v4"
)
// TODO: Handle TLS proxy
type (
// ProxyConfig defines the config for Proxy middleware.
ProxyConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Balancer defines a load balancing technique.
// Required.
Balancer ProxyBalancer
// Rewrite defines URL path rewrite rules. The values captured in asterisk can be
// retrieved by index e.g. $1, $2 and so on.
// Examples:
// "/old": "/new",
// "/api/*": "/$1",
// "/js/*": "/public/javascripts/$1",
// "/users/*/orders/*": "/user/$1/order/$2",
Rewrite map[string]string
// RegexRewrite defines rewrite rules using regexp.Rexexp with captures
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
// Example:
// "^/old/[0.9]+/": "/new",
// "^/api/.+?/(.*)": "/v2/$1",
RegexRewrite map[*regexp.Regexp]string
// Context key to store selected ProxyTarget into context.
// Optional. Default value "target".
ContextKey string
// To customize the transport to remote.
// Examples: If custom TLS certificates are required.
Transport http.RoundTripper
// ModifyResponse defines function to modify response from ProxyTarget.
ModifyResponse func(*http.Response) error
}
// ProxyTarget defines the upstream target.
ProxyTarget struct {
Name string
URL *url.URL
Meta echo.Map
}
// ProxyBalancer defines an interface to implement a load balancing technique.
ProxyBalancer interface {
AddTarget(*ProxyTarget) bool
RemoveTarget(string) bool
Next(echo.Context) *ProxyTarget
}
commonBalancer struct {
targets []*ProxyTarget
mutex sync.RWMutex
}
// RandomBalancer implements a random load balancing technique.
randomBalancer struct {
*commonBalancer
random *rand.Rand
}
// RoundRobinBalancer implements a round-robin load balancing technique.
roundRobinBalancer struct {
*commonBalancer
i uint32
}
)
var (
// DefaultProxyConfig is the default Proxy middleware config.
DefaultProxyConfig = ProxyConfig{
Skipper: DefaultSkipper,
ContextKey: "target",
}
)
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
in, _, err := c.Response().Hijack()
if err != nil {
c.Set("_error", fmt.Sprintf("proxy raw, hijack error=%v, url=%s", t.URL, err))
return
}
defer in.Close()
out, err := net.Dial("tcp", t.URL.Host)
if err != nil {
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err)))
return
}
defer out.Close()
// Write header
err = r.Write(out)
if err != nil {
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err)))
return
}
errCh := make(chan error, 2)
cp := func(dst io.Writer, src io.Reader) {
_, err = io.Copy(dst, src)
errCh <- err
}
go cp(out, in)
go cp(in, out)
err = <-errCh
if err != nil && err != io.EOF {
c.Set("_error", fmt.Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err))
}
})
}
// NewRandomBalancer returns a random proxy balancer.
func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer {
b := &randomBalancer{commonBalancer: new(commonBalancer)}
b.targets = targets
return b
}
// NewRoundRobinBalancer returns a round-robin proxy balancer.
func NewRoundRobinBalancer(targets []*ProxyTarget) ProxyBalancer {
b := &roundRobinBalancer{commonBalancer: new(commonBalancer)}
b.targets = targets
return b
}
// AddTarget adds an upstream target to the list.
func (b *commonBalancer) AddTarget(target *ProxyTarget) bool {
for _, t := range b.targets {
if t.Name == target.Name {
return false
}
}
b.mutex.Lock()
defer b.mutex.Unlock()
b.targets = append(b.targets, target)
return true
}
// RemoveTarget removes an upstream target from the list.
func (b *commonBalancer) RemoveTarget(name string) bool {
b.mutex.Lock()
defer b.mutex.Unlock()
for i, t := range b.targets {
if t.Name == name {
b.targets = append(b.targets[:i], b.targets[i+1:]...)
return true
}
}
return false
}
// Next randomly returns an upstream target.
func (b *randomBalancer) Next(c echo.Context) *ProxyTarget {
if b.random == nil {
b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
}
b.mutex.RLock()
defer b.mutex.RUnlock()
return b.targets[b.random.Intn(len(b.targets))]
}
// Next returns an upstream target using round-robin technique.
func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget {
b.i = b.i % uint32(len(b.targets))
t := b.targets[b.i]
atomic.AddUint32(&b.i, 1)
return t
}
// Proxy returns a Proxy middleware.
//
// Proxy middleware forwards the request to upstream server using a configured load balancing technique.
func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
c := DefaultProxyConfig
c.Balancer = balancer
return ProxyWithConfig(c)
}
// ProxyWithConfig returns a Proxy middleware with config.
// See: `Proxy()`
func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultProxyConfig.Skipper
}
if config.Balancer == nil {
panic("echo: proxy middleware requires balancer")
}
if config.Rewrite != nil {
if config.RegexRewrite == nil {
config.RegexRewrite = make(map[*regexp.Regexp]string)
}
for k, v := range rewriteRulesRegex(config.Rewrite) {
config.RegexRewrite[k] = v
}
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
res := c.Response()
tgt := config.Balancer.Next(c)
c.Set(config.ContextKey, tgt)
if err := rewriteURL(config.RegexRewrite, req); err != nil {
return err
}
// Fix header
// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
// However, for backward compatibility, legacy behavior is preserved unless you configure Echo#IPExtractor.
if req.Header.Get(echo.HeaderXRealIP) == "" || c.Echo().IPExtractor != nil {
req.Header.Set(echo.HeaderXRealIP, c.RealIP())
}
if req.Header.Get(echo.HeaderXForwardedProto) == "" {
req.Header.Set(echo.HeaderXForwardedProto, c.Scheme())
}
if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy.
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
}
// Proxy
switch {
case c.IsWebSocket():
proxyRaw(tgt, c).ServeHTTP(res, req)
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
default:
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
}
if e, ok := c.Get("_error").(error); ok {
err = e
}
return
}
}
}
// StatusCodeContextCanceled is a custom HTTP status code for situations
// where a client unexpectedly closed the connection to the server.
// As there is no standard error code for "client closed connection", but
// various well-known HTTP clients and server implement this HTTP code we use
// 499 too instead of the more problematic 5xx, which does not allow to detect this situation
const StatusCodeContextCanceled = 499
func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
proxy := httputil.NewSingleHostReverseProxy(tgt.URL)
proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) {
desc := tgt.URL.String()
if tgt.Name != "" {
desc = fmt.Sprintf("%s(%s)", tgt.Name, tgt.URL.String())
}
// If the client canceled the request (usually by closing the connection), we can report a
// client error (4xx) instead of a server error (5xx) to correctly identify the situation.
// The Go standard library (at of late 2020) wraps the exported, standard
// context.Canceled error with unexported garbage value requiring a substring check, see
// https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430
if err == context.Canceled || strings.Contains(err.Error(), "operation was canceled") {
httpError := echo.NewHTTPError(StatusCodeContextCanceled, fmt.Sprintf("client closed connection: %v", err))
httpError.Internal = err
c.Set("_error", httpError)
} else {
httpError := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err))
httpError.Internal = err
c.Set("_error", httpError)
}
}
proxy.Transport = config.Transport
proxy.ModifyResponse = config.ModifyResponse
return proxy
}

View file

@ -0,0 +1,267 @@
package middleware
import (
"net/http"
"sync"
"time"
"github.com/labstack/echo/v4"
"golang.org/x/time/rate"
)
type (
// RateLimiterStore is the interface to be implemented by custom stores.
RateLimiterStore interface {
// Stores for the rate limiter have to implement the Allow method
Allow(identifier string) (bool, error)
}
)
type (
// RateLimiterConfig defines the configuration for the rate limiter
RateLimiterConfig struct {
Skipper Skipper
BeforeFunc BeforeFunc
// IdentifierExtractor uses echo.Context to extract the identifier for a visitor
IdentifierExtractor Extractor
// Store defines a store for the rate limiter
Store RateLimiterStore
// ErrorHandler provides a handler to be called when IdentifierExtractor returns an error
ErrorHandler func(context echo.Context, err error) error
// DenyHandler provides a handler to be called when RateLimiter denies access
DenyHandler func(context echo.Context, identifier string, err error) error
}
// Extractor is used to extract data from echo.Context
Extractor func(context echo.Context) (string, error)
)
// errors
var (
// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
// ErrExtractorError denotes an error raised when extractor function is unsuccessful
ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
)
// DefaultRateLimiterConfig defines default values for RateLimiterConfig
var DefaultRateLimiterConfig = RateLimiterConfig{
Skipper: DefaultSkipper,
IdentifierExtractor: func(ctx echo.Context) (string, error) {
id := ctx.RealIP()
return id, nil
},
ErrorHandler: func(context echo.Context, err error) error {
return &echo.HTTPError{
Code: ErrExtractorError.Code,
Message: ErrExtractorError.Message,
Internal: err,
}
},
DenyHandler: func(context echo.Context, identifier string, err error) error {
return &echo.HTTPError{
Code: ErrRateLimitExceeded.Code,
Message: ErrRateLimitExceeded.Message,
Internal: err,
}
},
}
/*
RateLimiter returns a rate limiting middleware
e := echo.New()
limiterStore := middleware.NewRateLimiterMemoryStore(20)
e.GET("/rate-limited", func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}, RateLimiter(limiterStore))
*/
func RateLimiter(store RateLimiterStore) echo.MiddlewareFunc {
config := DefaultRateLimiterConfig
config.Store = store
return RateLimiterWithConfig(config)
}
/*
RateLimiterWithConfig returns a rate limiting middleware
e := echo.New()
config := middleware.RateLimiterConfig{
Skipper: DefaultSkipper,
Store: middleware.NewRateLimiterMemoryStore(
middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute}
)
IdentifierExtractor: func(ctx echo.Context) (string, error) {
id := ctx.RealIP()
return id, nil
},
ErrorHandler: func(context echo.Context, err error) error {
return context.JSON(http.StatusTooManyRequests, nil)
},
DenyHandler: func(context echo.Context, identifier string) error {
return context.JSON(http.StatusForbidden, nil)
},
}
e.GET("/rate-limited", func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}, middleware.RateLimiterWithConfig(config))
*/
func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
if config.Skipper == nil {
config.Skipper = DefaultRateLimiterConfig.Skipper
}
if config.IdentifierExtractor == nil {
config.IdentifierExtractor = DefaultRateLimiterConfig.IdentifierExtractor
}
if config.ErrorHandler == nil {
config.ErrorHandler = DefaultRateLimiterConfig.ErrorHandler
}
if config.DenyHandler == nil {
config.DenyHandler = DefaultRateLimiterConfig.DenyHandler
}
if config.Store == nil {
panic("Store configuration must be provided")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
if config.BeforeFunc != nil {
config.BeforeFunc(c)
}
identifier, err := config.IdentifierExtractor(c)
if err != nil {
c.Error(config.ErrorHandler(c, err))
return nil
}
if allow, err := config.Store.Allow(identifier); !allow {
c.Error(config.DenyHandler(c, identifier, err))
return nil
}
return next(c)
}
}
}
type (
// RateLimiterMemoryStore is the built-in store implementation for RateLimiter
RateLimiterMemoryStore struct {
visitors map[string]*Visitor
mutex sync.Mutex
rate rate.Limit
burst int
expiresIn time.Duration
lastCleanup time.Time
}
// Visitor signifies a unique user's limiter details
Visitor struct {
*rate.Limiter
lastSeen time.Time
}
)
/*
NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with
the provided rate (as req/s). The provided rate less than 1 will be treated as zero.
Burst and ExpiresIn will be set to default values.
Example (with 20 requests/sec):
limiterStore := middleware.NewRateLimiterMemoryStore(20)
*/
func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) {
return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
Rate: rate,
})
}
/*
NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore
with the provided configuration. Rate must be provided. Burst will be set to the value of
the configured rate if not provided or set to 0.
The build-in memory store is usually capable for modest loads. For higher loads other
store implementations should be considered.
Characteristics:
* Concurrency above 100 parallel requests may causes measurable lock contention
* A high number of different IP addresses (above 16000) may be impacted by the internally used Go map
* A high number of requests from a single IP address may cause lock contention
Example:
limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig(
middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minutes},
)
*/
func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) {
store = &RateLimiterMemoryStore{}
store.rate = config.Rate
store.burst = config.Burst
store.expiresIn = config.ExpiresIn
if config.ExpiresIn == 0 {
store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn
}
if config.Burst == 0 {
store.burst = int(config.Rate)
}
store.visitors = make(map[string]*Visitor)
store.lastCleanup = now()
return
}
// RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore
type RateLimiterMemoryStoreConfig struct {
Rate rate.Limit // Rate of requests allowed to pass as req/s
Burst int // Burst additionally allows a number of requests to pass when rate limit is reached
ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up
}
// DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore
var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{
ExpiresIn: 3 * time.Minute,
}
// Allow implements RateLimiterStore.Allow
func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
store.mutex.Lock()
limiter, exists := store.visitors[identifier]
if !exists {
limiter = new(Visitor)
limiter.Limiter = rate.NewLimiter(store.rate, store.burst)
store.visitors[identifier] = limiter
}
limiter.lastSeen = now()
if now().Sub(store.lastCleanup) > store.expiresIn {
store.cleanupStaleVisitors()
}
store.mutex.Unlock()
return limiter.AllowN(now(), 1), nil
}
/*
cleanupStaleVisitors helps manage the size of the visitors map by removing stale records
of users who haven't visited again after the configured expiry time has elapsed
*/
func (store *RateLimiterMemoryStore) cleanupStaleVisitors() {
for id, visitor := range store.visitors {
if now().Sub(visitor.lastSeen) > store.expiresIn {
delete(store.visitors, id)
}
}
store.lastCleanup = now()
}
/*
actual time method which is mocked in test file
*/
var now = time.Now

View file

@ -0,0 +1,101 @@
package middleware
import (
"fmt"
"runtime"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
)
type (
// RecoverConfig defines the config for Recover middleware.
RecoverConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Size of the stack to be printed.
// Optional. Default value 4KB.
StackSize int `yaml:"stack_size"`
// DisableStackAll disables formatting stack traces of all other goroutines
// into buffer after the trace for the current goroutine.
// Optional. Default value false.
DisableStackAll bool `yaml:"disable_stack_all"`
// DisablePrintStack disables printing stack trace.
// Optional. Default value as false.
DisablePrintStack bool `yaml:"disable_print_stack"`
// LogLevel is log level to printing stack trace.
// Optional. Default value 0 (Print).
LogLevel log.Lvl
}
)
var (
// DefaultRecoverConfig is the default Recover middleware config.
DefaultRecoverConfig = RecoverConfig{
Skipper: DefaultSkipper,
StackSize: 4 << 10, // 4 KB
DisableStackAll: false,
DisablePrintStack: false,
LogLevel: 0,
}
)
// Recover returns a middleware which recovers from panics anywhere in the chain
// and handles the control to the centralized HTTPErrorHandler.
func Recover() echo.MiddlewareFunc {
return RecoverWithConfig(DefaultRecoverConfig)
}
// RecoverWithConfig returns a Recover middleware with config.
// See: `Recover()`.
func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultRecoverConfig.Skipper
}
if config.StackSize == 0 {
config.StackSize = DefaultRecoverConfig.StackSize
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
defer func() {
if r := recover(); r != nil {
err, ok := r.(error)
if !ok {
err = fmt.Errorf("%v", r)
}
stack := make([]byte, config.StackSize)
length := runtime.Stack(stack, !config.DisableStackAll)
if !config.DisablePrintStack {
msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length])
switch config.LogLevel {
case log.DEBUG:
c.Logger().Debug(msg)
case log.INFO:
c.Logger().Info(msg)
case log.WARN:
c.Logger().Warn(msg)
case log.ERROR:
c.Logger().Error(msg)
case log.OFF:
// None.
default:
c.Logger().Print(msg)
}
}
c.Error(err)
}
}()
return next(c)
}
}
}

View file

@ -0,0 +1,152 @@
package middleware
import (
"net/http"
"strings"
"github.com/labstack/echo/v4"
)
// RedirectConfig defines the config for Redirect middleware.
type RedirectConfig struct {
// Skipper defines a function to skip middleware.
Skipper
// Status code to be used when redirecting the request.
// Optional. Default value http.StatusMovedPermanently.
Code int `yaml:"code"`
}
// redirectLogic represents a function that given a scheme, host and uri
// can both: 1) determine if redirect is needed (will set ok accordingly) and
// 2) return the appropriate redirect url.
type redirectLogic func(scheme, host, uri string) (ok bool, url string)
const www = "www."
// DefaultRedirectConfig is the default Redirect middleware config.
var DefaultRedirectConfig = RedirectConfig{
Skipper: DefaultSkipper,
Code: http.StatusMovedPermanently,
}
// HTTPSRedirect redirects http requests to https.
// For example, http://labstack.com will be redirect to https://labstack.com.
//
// Usage `Echo#Pre(HTTPSRedirect())`
func HTTPSRedirect() echo.MiddlewareFunc {
return HTTPSRedirectWithConfig(DefaultRedirectConfig)
}
// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `HTTPSRedirect()`.
func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) {
if scheme != "https" {
return true, "https://" + host + uri
}
return false, ""
})
}
// HTTPSWWWRedirect redirects http requests to https www.
// For example, http://labstack.com will be redirect to https://www.labstack.com.
//
// Usage `Echo#Pre(HTTPSWWWRedirect())`
func HTTPSWWWRedirect() echo.MiddlewareFunc {
return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig)
}
// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `HTTPSWWWRedirect()`.
func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) {
if scheme != "https" && !strings.HasPrefix(host, www) {
return true, "https://www." + host + uri
}
return false, ""
})
}
// HTTPSNonWWWRedirect redirects http requests to https non www.
// For example, http://www.labstack.com will be redirect to https://labstack.com.
//
// Usage `Echo#Pre(HTTPSNonWWWRedirect())`
func HTTPSNonWWWRedirect() echo.MiddlewareFunc {
return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig)
}
// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `HTTPSNonWWWRedirect()`.
func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
if scheme != "https" {
host = strings.TrimPrefix(host, www)
return true, "https://" + host + uri
}
return false, ""
})
}
// WWWRedirect redirects non www requests to www.
// For example, http://labstack.com will be redirect to http://www.labstack.com.
//
// Usage `Echo#Pre(WWWRedirect())`
func WWWRedirect() echo.MiddlewareFunc {
return WWWRedirectWithConfig(DefaultRedirectConfig)
}
// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `WWWRedirect()`.
func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) {
if !strings.HasPrefix(host, www) {
return true, scheme + "://www." + host + uri
}
return false, ""
})
}
// NonWWWRedirect redirects www requests to non www.
// For example, http://www.labstack.com will be redirect to http://labstack.com.
//
// Usage `Echo#Pre(NonWWWRedirect())`
func NonWWWRedirect() echo.MiddlewareFunc {
return NonWWWRedirectWithConfig(DefaultRedirectConfig)
}
// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `NonWWWRedirect()`.
func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) {
if strings.HasPrefix(host, www) {
return true, scheme + "://" + host[4:] + uri
}
return false, ""
})
}
func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc {
if config.Skipper == nil {
config.Skipper = DefaultRedirectConfig.Skipper
}
if config.Code == 0 {
config.Code = DefaultRedirectConfig.Code
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req, scheme := c.Request(), c.Scheme()
host := req.Host
if ok, url := cb(scheme, host, req.RequestURI); ok {
return c.Redirect(config.Code, url)
}
return next(c)
}
}
}

View file

@ -0,0 +1,70 @@
package middleware
import (
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
)
type (
// RequestIDConfig defines the config for RequestID middleware.
RequestIDConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Generator defines a function to generate an ID.
// Optional. Default value random.String(32).
Generator func() string
// RequestIDHandler defines a function which is executed for a request id.
RequestIDHandler func(echo.Context, string)
}
)
var (
// DefaultRequestIDConfig is the default RequestID middleware config.
DefaultRequestIDConfig = RequestIDConfig{
Skipper: DefaultSkipper,
Generator: generator,
}
)
// RequestID returns a X-Request-ID middleware.
func RequestID() echo.MiddlewareFunc {
return RequestIDWithConfig(DefaultRequestIDConfig)
}
// RequestIDWithConfig returns a X-Request-ID middleware with config.
func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultRequestIDConfig.Skipper
}
if config.Generator == nil {
config.Generator = generator
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
res := c.Response()
rid := req.Header.Get(echo.HeaderXRequestID)
if rid == "" {
rid = config.Generator()
}
res.Header().Set(echo.HeaderXRequestID, rid)
if config.RequestIDHandler != nil {
config.RequestIDHandler(c, rid)
}
return next(c)
}
}
}
func generator() string {
return random.String(32)
}

View file

@ -0,0 +1,81 @@
package middleware
import (
"regexp"
"github.com/labstack/echo/v4"
)
type (
// RewriteConfig defines the config for Rewrite middleware.
RewriteConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Rules defines the URL path rewrite rules. The values captured in asterisk can be
// retrieved by index e.g. $1, $2 and so on.
// Example:
// "/old": "/new",
// "/api/*": "/$1",
// "/js/*": "/public/javascripts/$1",
// "/users/*/orders/*": "/user/$1/order/$2",
// Required.
Rules map[string]string `yaml:"rules"`
// RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
// Example:
// "^/old/[0.9]+/": "/new",
// "^/api/.+?/(.*)": "/v2/$1",
RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"`
}
)
var (
// DefaultRewriteConfig is the default Rewrite middleware config.
DefaultRewriteConfig = RewriteConfig{
Skipper: DefaultSkipper,
}
)
// Rewrite returns a Rewrite middleware.
//
// Rewrite middleware rewrites the URL path based on the provided rules.
func Rewrite(rules map[string]string) echo.MiddlewareFunc {
c := DefaultRewriteConfig
c.Rules = rules
return RewriteWithConfig(c)
}
// RewriteWithConfig returns a Rewrite middleware with config.
// See: `Rewrite()`.
func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
// Defaults
if config.Rules == nil && config.RegexRules == nil {
panic("echo: rewrite middleware requires url path rewrite rules or regex rules")
}
if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper
}
if config.RegexRules == nil {
config.RegexRules = make(map[*regexp.Regexp]string)
}
for k, v := range rewriteRulesRegex(config.Rules) {
config.RegexRules[k] = v
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
if config.Skipper(c) {
return next(c)
}
if err := rewriteURL(config.RegexRules, c.Request()); err != nil {
return err
}
return next(c)
}
}
}

145
vendor/github.com/labstack/echo/v4/middleware/secure.go generated vendored Normal file
View file

@ -0,0 +1,145 @@
package middleware
import (
"fmt"
"github.com/labstack/echo/v4"
)
type (
// SecureConfig defines the config for Secure middleware.
SecureConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// XSSProtection provides protection against cross-site scripting attack (XSS)
// by setting the `X-XSS-Protection` header.
// Optional. Default value "1; mode=block".
XSSProtection string `yaml:"xss_protection"`
// ContentTypeNosniff provides protection against overriding Content-Type
// header by setting the `X-Content-Type-Options` header.
// Optional. Default value "nosniff".
ContentTypeNosniff string `yaml:"content_type_nosniff"`
// XFrameOptions can be used to indicate whether or not a browser should
// be allowed to render a page in a <frame>, <iframe> or <object> .
// Sites can use this to avoid clickjacking attacks, by ensuring that their
// content is not embedded into other sites.provides protection against
// clickjacking.
// Optional. Default value "SAMEORIGIN".
// Possible values:
// - "SAMEORIGIN" - The page can only be displayed in a frame on the same origin as the page itself.
// - "DENY" - The page cannot be displayed in a frame, regardless of the site attempting to do so.
// - "ALLOW-FROM uri" - The page can only be displayed in a frame on the specified origin.
XFrameOptions string `yaml:"x_frame_options"`
// HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how
// long (in seconds) browsers should remember that this site is only to
// be accessed using HTTPS. This reduces your exposure to some SSL-stripping
// man-in-the-middle (MITM) attacks.
// Optional. Default value 0.
HSTSMaxAge int `yaml:"hsts_max_age"`
// HSTSExcludeSubdomains won't include subdomains tag in the `Strict Transport Security`
// header, excluding all subdomains from security policy. It has no effect
// unless HSTSMaxAge is set to a non-zero value.
// Optional. Default value false.
HSTSExcludeSubdomains bool `yaml:"hsts_exclude_subdomains"`
// ContentSecurityPolicy sets the `Content-Security-Policy` header providing
// security against cross-site scripting (XSS), clickjacking and other code
// injection attacks resulting from execution of malicious content in the
// trusted web page context.
// Optional. Default value "".
ContentSecurityPolicy string `yaml:"content_security_policy"`
// CSPReportOnly would use the `Content-Security-Policy-Report-Only` header instead
// of the `Content-Security-Policy` header. This allows iterative updates of the
// content security policy by only reporting the violations that would
// have occurred instead of blocking the resource.
// Optional. Default value false.
CSPReportOnly bool `yaml:"csp_report_only"`
// HSTSPreloadEnabled will add the preload tag in the `Strict Transport Security`
// header, which enables the domain to be included in the HSTS preload list
// maintained by Chrome (and used by Firefox and Safari): https://hstspreload.org/
// Optional. Default value false.
HSTSPreloadEnabled bool `yaml:"hsts_preload_enabled"`
// ReferrerPolicy sets the `Referrer-Policy` header providing security against
// leaking potentially sensitive request paths to third parties.
// Optional. Default value "".
ReferrerPolicy string `yaml:"referrer_policy"`
}
)
var (
// DefaultSecureConfig is the default Secure middleware config.
DefaultSecureConfig = SecureConfig{
Skipper: DefaultSkipper,
XSSProtection: "1; mode=block",
ContentTypeNosniff: "nosniff",
XFrameOptions: "SAMEORIGIN",
HSTSPreloadEnabled: false,
}
)
// Secure returns a Secure middleware.
// Secure middleware provides protection against cross-site scripting (XSS) attack,
// content type sniffing, clickjacking, insecure connection and other code injection
// attacks.
func Secure() echo.MiddlewareFunc {
return SecureWithConfig(DefaultSecureConfig)
}
// SecureWithConfig returns a Secure middleware with config.
// See: `Secure()`.
func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultSecureConfig.Skipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
res := c.Response()
if config.XSSProtection != "" {
res.Header().Set(echo.HeaderXXSSProtection, config.XSSProtection)
}
if config.ContentTypeNosniff != "" {
res.Header().Set(echo.HeaderXContentTypeOptions, config.ContentTypeNosniff)
}
if config.XFrameOptions != "" {
res.Header().Set(echo.HeaderXFrameOptions, config.XFrameOptions)
}
if (c.IsTLS() || (req.Header.Get(echo.HeaderXForwardedProto) == "https")) && config.HSTSMaxAge != 0 {
subdomains := ""
if !config.HSTSExcludeSubdomains {
subdomains = "; includeSubdomains"
}
if config.HSTSPreloadEnabled {
subdomains = fmt.Sprintf("%s; preload", subdomains)
}
res.Header().Set(echo.HeaderStrictTransportSecurity, fmt.Sprintf("max-age=%d%s", config.HSTSMaxAge, subdomains))
}
if config.ContentSecurityPolicy != "" {
if config.CSPReportOnly {
res.Header().Set(echo.HeaderContentSecurityPolicyReportOnly, config.ContentSecurityPolicy)
} else {
res.Header().Set(echo.HeaderContentSecurityPolicy, config.ContentSecurityPolicy)
}
}
if config.ReferrerPolicy != "" {
res.Header().Set(echo.HeaderReferrerPolicy, config.ReferrerPolicy)
}
return next(c)
}
}
}

130
vendor/github.com/labstack/echo/v4/middleware/slash.go generated vendored Normal file
View file

@ -0,0 +1,130 @@
package middleware
import (
"strings"
"github.com/labstack/echo/v4"
)
type (
// TrailingSlashConfig defines the config for TrailingSlash middleware.
TrailingSlashConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Status code to be used when redirecting the request.
// Optional, but when provided the request is redirected using this code.
RedirectCode int `yaml:"redirect_code"`
}
)
var (
// DefaultTrailingSlashConfig is the default TrailingSlash middleware config.
DefaultTrailingSlashConfig = TrailingSlashConfig{
Skipper: DefaultSkipper,
}
)
// AddTrailingSlash returns a root level (before router) middleware which adds a
// trailing slash to the request `URL#Path`.
//
// Usage `Echo#Pre(AddTrailingSlash())`
func AddTrailingSlash() echo.MiddlewareFunc {
return AddTrailingSlashWithConfig(DefaultTrailingSlashConfig)
}
// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config.
// See `AddTrailingSlash()`.
func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultTrailingSlashConfig.Skipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
url := req.URL
path := url.Path
qs := c.QueryString()
if !strings.HasSuffix(path, "/") {
path += "/"
uri := path
if qs != "" {
uri += "?" + qs
}
// Redirect
if config.RedirectCode != 0 {
return c.Redirect(config.RedirectCode, sanitizeURI(uri))
}
// Forward
req.RequestURI = uri
url.Path = path
}
return next(c)
}
}
}
// RemoveTrailingSlash returns a root level (before router) middleware which removes
// a trailing slash from the request URI.
//
// Usage `Echo#Pre(RemoveTrailingSlash())`
func RemoveTrailingSlash() echo.MiddlewareFunc {
return RemoveTrailingSlashWithConfig(TrailingSlashConfig{})
}
// RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware with config.
// See `RemoveTrailingSlash()`.
func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultTrailingSlashConfig.Skipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
url := req.URL
path := url.Path
qs := c.QueryString()
l := len(path) - 1
if l > 0 && strings.HasSuffix(path, "/") {
path = path[:l]
uri := path
if qs != "" {
uri += "?" + qs
}
// Redirect
if config.RedirectCode != 0 {
return c.Redirect(config.RedirectCode, sanitizeURI(uri))
}
// Forward
req.RequestURI = uri
url.Path = path
}
return next(c)
}
}
}
func sanitizeURI(uri string) string {
// double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri
// we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash
if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') {
uri = "/" + strings.TrimLeft(uri, `/\`)
}
return uri
}

276
vendor/github.com/labstack/echo/v4/middleware/static.go generated vendored Normal file
View file

@ -0,0 +1,276 @@
package middleware
import (
"fmt"
"html/template"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/bytes"
)
type (
// StaticConfig defines the config for Static middleware.
StaticConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Root directory from where the static content is served.
// Required.
Root string `yaml:"root"`
// Index file for serving a directory.
// Optional. Default value "index.html".
Index string `yaml:"index"`
// Enable HTML5 mode by forwarding all not-found requests to root so that
// SPA (single-page application) can handle the routing.
// Optional. Default value false.
HTML5 bool `yaml:"html5"`
// Enable directory browsing.
// Optional. Default value false.
Browse bool `yaml:"browse"`
// Enable ignoring of the base of the URL path.
// Example: when assigning a static middleware to a non root path group,
// the filesystem path is not doubled
// Optional. Default value false.
IgnoreBase bool `yaml:"ignoreBase"`
// Filesystem provides access to the static content.
// Optional. Defaults to http.Dir(config.Root)
Filesystem http.FileSystem `yaml:"-"`
}
)
const html = `
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="X-UA-Compatible" content="ie=edge">
<title>{{ .Name }}</title>
<style>
body {
font-family: Menlo, Consolas, monospace;
padding: 48px;
}
header {
padding: 4px 16px;
font-size: 24px;
}
ul {
list-style-type: none;
margin: 0;
padding: 20px 0 0 0;
display: flex;
flex-wrap: wrap;
}
li {
width: 300px;
padding: 16px;
}
li a {
display: block;
overflow: hidden;
white-space: nowrap;
text-overflow: ellipsis;
text-decoration: none;
transition: opacity 0.25s;
}
li span {
color: #707070;
font-size: 12px;
}
li a:hover {
opacity: 0.50;
}
.dir {
color: #E91E63;
}
.file {
color: #673AB7;
}
</style>
</head>
<body>
<header>
{{ .Name }}
</header>
<ul>
{{ range .Files }}
<li>
{{ if .Dir }}
{{ $name := print .Name "/" }}
<a class="dir" href="{{ $name }}">{{ $name }}</a>
{{ else }}
<a class="file" href="{{ .Name }}">{{ .Name }}</a>
<span>{{ .Size }}</span>
{{ end }}
</li>
{{ end }}
</ul>
</body>
</html>
`
var (
// DefaultStaticConfig is the default Static middleware config.
DefaultStaticConfig = StaticConfig{
Skipper: DefaultSkipper,
Index: "index.html",
}
)
// Static returns a Static middleware to serves static content from the provided
// root directory.
func Static(root string) echo.MiddlewareFunc {
c := DefaultStaticConfig
c.Root = root
return StaticWithConfig(c)
}
// StaticWithConfig returns a Static middleware with config.
// See `Static()`.
func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
// Defaults
if config.Root == "" {
config.Root = "." // For security we want to restrict to CWD.
}
if config.Skipper == nil {
config.Skipper = DefaultStaticConfig.Skipper
}
if config.Index == "" {
config.Index = DefaultStaticConfig.Index
}
if config.Filesystem == nil {
config.Filesystem = http.Dir(config.Root)
config.Root = "."
}
// Index template
t, err := template.New("index").Parse(html)
if err != nil {
panic(fmt.Sprintf("echo: %v", err))
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
if config.Skipper(c) {
return next(c)
}
p := c.Request().URL.Path
if strings.HasSuffix(c.Path(), "*") { // When serving from a group, e.g. `/static*`.
p = c.Param("*")
}
p, err = url.PathUnescape(p)
if err != nil {
return
}
name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security
if config.IgnoreBase {
routePath := path.Base(strings.TrimRight(c.Path(), "/*"))
baseURLPath := path.Base(p)
if baseURLPath == routePath {
i := strings.LastIndex(name, routePath)
name = name[:i] + strings.Replace(name[i:], routePath, "", 1)
}
}
file, err := openFile(config.Filesystem, name)
if err != nil {
if !os.IsNotExist(err) {
return err
}
if err = next(c); err == nil {
return err
}
he, ok := err.(*echo.HTTPError)
if !(ok && config.HTML5 && he.Code == http.StatusNotFound) {
return err
}
file, err = openFile(config.Filesystem, filepath.Join(config.Root, config.Index))
if err != nil {
return err
}
}
defer file.Close()
info, err := file.Stat()
if err != nil {
return err
}
if info.IsDir() {
index, err := openFile(config.Filesystem, filepath.Join(name, config.Index))
if err != nil {
if config.Browse {
return listDir(t, name, file, c.Response())
}
if os.IsNotExist(err) {
return next(c)
}
}
defer index.Close()
info, err = index.Stat()
if err != nil {
return err
}
return serveFile(c, index, info)
}
return serveFile(c, file, info)
}
}
}
func openFile(fs http.FileSystem, name string) (http.File, error) {
pathWithSlashes := filepath.ToSlash(name)
return fs.Open(pathWithSlashes)
}
func serveFile(c echo.Context, file http.File, info os.FileInfo) error {
http.ServeContent(c.Response(), c.Request(), info.Name(), info.ModTime(), file)
return nil
}
func listDir(t *template.Template, name string, dir http.File, res *echo.Response) (err error) {
files, err := dir.Readdir(-1)
if err != nil {
return
}
// Create directory index
res.Header().Set(echo.HeaderContentType, echo.MIMETextHTMLCharsetUTF8)
data := struct {
Name string
Files []interface{}
}{
Name: name,
}
for _, f := range files {
data.Files = append(data.Files, struct {
Name string
Dir bool
Size string
}{f.Name(), f.IsDir(), bytes.Format(f.Size())})
}
return t.Execute(res, data)
}

View file

@ -0,0 +1,128 @@
package middleware
import (
"context"
"net/http"
"time"
"github.com/labstack/echo/v4"
)
type (
// TimeoutConfig defines the config for Timeout middleware.
TimeoutConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code
// It can be used to define a custom timeout error message
ErrorMessage string
// OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after
// request timeouted and we already had sent the error code (503) and message response to the client.
// NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer
// will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()`
OnTimeoutRouteErrorHandler func(err error, c echo.Context)
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
// difference over 500microseconds (0.5millisecond) response seems to be reliable
Timeout time.Duration
}
)
var (
// DefaultTimeoutConfig is the default Timeout middleware config.
DefaultTimeoutConfig = TimeoutConfig{
Skipper: DefaultSkipper,
Timeout: 0,
ErrorMessage: "",
}
)
// Timeout returns a middleware which returns error (503 Service Unavailable error) to client immediately when handler
// call runs for longer than its time limit. NB: timeout does not stop handler execution.
func Timeout() echo.MiddlewareFunc {
return TimeoutWithConfig(DefaultTimeoutConfig)
}
// TimeoutWithConfig returns a Timeout middleware with config.
// See: `Timeout()`.
func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultTimeoutConfig.Skipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) || config.Timeout == 0 {
return next(c)
}
handlerWrapper := echoHandlerFuncWrapper{
ctx: c,
handler: next,
errChan: make(chan error, 1),
errHandler: config.OnTimeoutRouteErrorHandler,
}
handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage)
handler.ServeHTTP(c.Response().Writer, c.Request())
select {
case err := <-handlerWrapper.errChan:
return err
default:
return nil
}
}
}
}
type echoHandlerFuncWrapper struct {
ctx echo.Context
handler echo.HandlerFunc
errHandler func(err error, c echo.Context)
errChan chan error
}
func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// replace echo.Context Request with the one provided by TimeoutHandler to let later middlewares/handler on the chain
// handle properly it's cancellation
t.ctx.SetRequest(r)
// replace writer with TimeoutHandler custom one. This will guarantee that
// `writes by h to its ResponseWriter will return ErrHandlerTimeout.`
originalWriter := t.ctx.Response().Writer
t.ctx.Response().Writer = rw
// in case of panic we restore original writer and call panic again
// so it could be handled with global middleware Recover()
defer func() {
if err := recover(); err != nil {
t.ctx.Response().Writer = originalWriter
panic(err)
}
}()
err := t.handler(t.ctx)
if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded {
if err != nil && t.errHandler != nil {
t.errHandler(err, t.ctx)
}
return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers
}
// we restore original writer only for cases we did not timeout. On timeout we have already sent response to client
// and should not anymore send additional headers/data
// so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body
t.ctx.Response().Writer = originalWriter
if err != nil {
// call global error handler to write error to the client. This is needed or `http.TimeoutHandler` will send status code by itself
// and after that our tries to write status code will not work anymore
t.ctx.Error(err)
// we pass error from handler to middlewares up in handler chain to act on it if needed. But this means that
// global error handler is probably be called twice as `t.ctx.Error` already does that.
t.errChan <- err
}
}

54
vendor/github.com/labstack/echo/v4/middleware/util.go generated vendored Normal file
View file

@ -0,0 +1,54 @@
package middleware
import (
"strings"
)
func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":")
pidx := strings.Index(pattern, ":")
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
}
// matchSubdomain compares authority with wildcard
func matchSubdomain(domain, pattern string) bool {
if !matchScheme(domain, pattern) {
return false
}
didx := strings.Index(domain, "://")
pidx := strings.Index(pattern, "://")
if didx == -1 || pidx == -1 {
return false
}
domAuth := domain[didx+3:]
// to avoid long loop by invalid long domain
if len(domAuth) > 253 {
return false
}
patAuth := pattern[pidx+3:]
domComp := strings.Split(domAuth, ".")
patComp := strings.Split(patAuth, ".")
for i := len(domComp)/2 - 1; i >= 0; i-- {
opp := len(domComp) - 1 - i
domComp[i], domComp[opp] = domComp[opp], domComp[i]
}
for i := len(patComp)/2 - 1; i >= 0; i-- {
opp := len(patComp) - 1 - i
patComp[i], patComp[opp] = patComp[opp], patComp[i]
}
for i, v := range domComp {
if len(patComp) <= i {
return false
}
p := patComp[i]
if p == "*" {
return true
}
if p != v {
return false
}
}
return false
}