recover from panics such as out-of-bounds array access & nil pointer access, print a stack trace and return 5xx error instead of the service crashing and relying on Execution framework to handle crashes
221 lines
6.1 KiB
Go
221 lines
6.1 KiB
Go
package middleware
|
|
|
|
import (
|
|
"crypto/subtle"
|
|
"errors"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/labstack/gommon/random"
|
|
)
|
|
|
|
type (
|
|
// CSRFConfig defines the config for CSRF middleware.
|
|
CSRFConfig struct {
|
|
// Skipper defines a function to skip middleware.
|
|
Skipper Skipper
|
|
|
|
// TokenLength is the length of the generated token.
|
|
TokenLength uint8 `yaml:"token_length"`
|
|
// Optional. Default value 32.
|
|
|
|
// TokenLookup is a string in the form of "<source>:<key>" that is used
|
|
// to extract token from the request.
|
|
// Optional. Default value "header:X-CSRF-Token".
|
|
// Possible values:
|
|
// - "header:<name>"
|
|
// - "form:<name>"
|
|
// - "query:<name>"
|
|
TokenLookup string `yaml:"token_lookup"`
|
|
|
|
// Context key to store generated CSRF token into context.
|
|
// Optional. Default value "csrf".
|
|
ContextKey string `yaml:"context_key"`
|
|
|
|
// Name of the CSRF cookie. This cookie will store CSRF token.
|
|
// Optional. Default value "csrf".
|
|
CookieName string `yaml:"cookie_name"`
|
|
|
|
// Domain of the CSRF cookie.
|
|
// Optional. Default value none.
|
|
CookieDomain string `yaml:"cookie_domain"`
|
|
|
|
// Path of the CSRF cookie.
|
|
// Optional. Default value none.
|
|
CookiePath string `yaml:"cookie_path"`
|
|
|
|
// Max age (in seconds) of the CSRF cookie.
|
|
// Optional. Default value 86400 (24hr).
|
|
CookieMaxAge int `yaml:"cookie_max_age"`
|
|
|
|
// Indicates if CSRF cookie is secure.
|
|
// Optional. Default value false.
|
|
CookieSecure bool `yaml:"cookie_secure"`
|
|
|
|
// Indicates if CSRF cookie is HTTP only.
|
|
// Optional. Default value false.
|
|
CookieHTTPOnly bool `yaml:"cookie_http_only"`
|
|
|
|
// Indicates SameSite mode of the CSRF cookie.
|
|
// Optional. Default value SameSiteDefaultMode.
|
|
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
|
|
}
|
|
|
|
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
|
|
// either a token or an error.
|
|
csrfTokenExtractor func(echo.Context) (string, error)
|
|
)
|
|
|
|
var (
|
|
// DefaultCSRFConfig is the default CSRF middleware config.
|
|
DefaultCSRFConfig = CSRFConfig{
|
|
Skipper: DefaultSkipper,
|
|
TokenLength: 32,
|
|
TokenLookup: "header:" + echo.HeaderXCSRFToken,
|
|
ContextKey: "csrf",
|
|
CookieName: "_csrf",
|
|
CookieMaxAge: 86400,
|
|
CookieSameSite: http.SameSiteDefaultMode,
|
|
}
|
|
)
|
|
|
|
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
|
|
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
|
|
func CSRF() echo.MiddlewareFunc {
|
|
c := DefaultCSRFConfig
|
|
return CSRFWithConfig(c)
|
|
}
|
|
|
|
// CSRFWithConfig returns a CSRF middleware with config.
|
|
// See `CSRF()`.
|
|
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
|
// Defaults
|
|
if config.Skipper == nil {
|
|
config.Skipper = DefaultCSRFConfig.Skipper
|
|
}
|
|
if config.TokenLength == 0 {
|
|
config.TokenLength = DefaultCSRFConfig.TokenLength
|
|
}
|
|
if config.TokenLookup == "" {
|
|
config.TokenLookup = DefaultCSRFConfig.TokenLookup
|
|
}
|
|
if config.ContextKey == "" {
|
|
config.ContextKey = DefaultCSRFConfig.ContextKey
|
|
}
|
|
if config.CookieName == "" {
|
|
config.CookieName = DefaultCSRFConfig.CookieName
|
|
}
|
|
if config.CookieMaxAge == 0 {
|
|
config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge
|
|
}
|
|
if config.CookieSameSite == http.SameSiteNoneMode {
|
|
config.CookieSecure = true
|
|
}
|
|
|
|
// Initialize
|
|
parts := strings.Split(config.TokenLookup, ":")
|
|
extractor := csrfTokenFromHeader(parts[1])
|
|
switch parts[0] {
|
|
case "form":
|
|
extractor = csrfTokenFromForm(parts[1])
|
|
case "query":
|
|
extractor = csrfTokenFromQuery(parts[1])
|
|
}
|
|
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
if config.Skipper(c) {
|
|
return next(c)
|
|
}
|
|
|
|
req := c.Request()
|
|
k, err := c.Cookie(config.CookieName)
|
|
token := ""
|
|
|
|
// Generate token
|
|
if err != nil {
|
|
token = random.String(config.TokenLength)
|
|
} else {
|
|
// Reuse token
|
|
token = k.Value
|
|
}
|
|
|
|
switch req.Method {
|
|
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
|
|
default:
|
|
// Validate token only for requests which are not defined as 'safe' by RFC7231
|
|
clientToken, err := extractor(c)
|
|
if err != nil {
|
|
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
|
}
|
|
if !validateCSRFToken(token, clientToken) {
|
|
return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
|
|
}
|
|
}
|
|
|
|
// Set CSRF cookie
|
|
cookie := new(http.Cookie)
|
|
cookie.Name = config.CookieName
|
|
cookie.Value = token
|
|
if config.CookiePath != "" {
|
|
cookie.Path = config.CookiePath
|
|
}
|
|
if config.CookieDomain != "" {
|
|
cookie.Domain = config.CookieDomain
|
|
}
|
|
if config.CookieSameSite != http.SameSiteDefaultMode {
|
|
cookie.SameSite = config.CookieSameSite
|
|
}
|
|
cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
|
|
cookie.Secure = config.CookieSecure
|
|
cookie.HttpOnly = config.CookieHTTPOnly
|
|
c.SetCookie(cookie)
|
|
|
|
// Store token in the context
|
|
c.Set(config.ContextKey, token)
|
|
|
|
// Protect clients from caching the response
|
|
c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)
|
|
|
|
return next(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
|
|
// provided request header.
|
|
func csrfTokenFromHeader(header string) csrfTokenExtractor {
|
|
return func(c echo.Context) (string, error) {
|
|
return c.Request().Header.Get(header), nil
|
|
}
|
|
}
|
|
|
|
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
|
|
// provided form parameter.
|
|
func csrfTokenFromForm(param string) csrfTokenExtractor {
|
|
return func(c echo.Context) (string, error) {
|
|
token := c.FormValue(param)
|
|
if token == "" {
|
|
return "", errors.New("missing csrf token in the form parameter")
|
|
}
|
|
return token, nil
|
|
}
|
|
}
|
|
|
|
// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
|
|
// provided query parameter.
|
|
func csrfTokenFromQuery(param string) csrfTokenExtractor {
|
|
return func(c echo.Context) (string, error) {
|
|
token := c.QueryParam(param)
|
|
if token == "" {
|
|
return "", errors.New("missing csrf token in the query string")
|
|
}
|
|
return token, nil
|
|
}
|
|
}
|
|
|
|
// validateCSRFToken reports whether clientToken matches the expected token.
//
// An empty token on either side never validates: comparing two empty values
// would otherwise succeed and let a request through without any CSRF token.
// Non-empty tokens are compared with subtle.ConstantTimeCompare so that the
// comparison time does not depend on the tokens' contents.
func validateCSRFToken(token, clientToken string) bool {
	if token == "" || clientToken == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
}
|