From 65d214f47d49a6a5637598d8fac7523f5fb40fa8 Mon Sep 17 00:00:00 2001 From: "marko.tikvic" Date: Thu, 29 Aug 2019 16:03:57 +0200 Subject: [PATCH] improved middleware: more control over content type --- auth.go | 2 +- http.go | 21 +++++++++++++-------- middleware/main.go | 12 ++++++------ middleware/middleware.go | 42 +++++++++++++++++++++++++++--------------- 4 files changed, 47 insertions(+), 30 deletions(-) diff --git a/auth.go b/auth.go index 563d78e..d2351f6 100644 --- a/auth.go +++ b/auth.go @@ -171,7 +171,7 @@ func GetTokenClaims(req *http.Request) (*TokenClaims, error) { if ok := strings.HasPrefix(authHead, "Bearer "); ok { tokstr = strings.TrimPrefix(authHead, "Bearer ") } else { - return &TokenClaims{}, errors.New("authorization header in incomplete") + return &TokenClaims{}, errors.New("authorization header is incomplete") } token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc) diff --git a/http.go b/http.go index a2fc796..6978e45 100644 --- a/http.go +++ b/http.go @@ -51,14 +51,12 @@ func (r *StatusRecorder) Size() int { // NotFoundHandlerFunc writes HTTP error 404 to w. func NotFoundHandlerFunc(w http.ResponseWriter, req *http.Request) { - SetDefaultHeaders(w) - if req.Method == "OPTIONS" { - return - } + SetAccessControlHeaders(w) + SetContentType(w, "application/json") NotFound(w, req, fmt.Sprintf("Resource you requested was not found: %s", req.URL.String())) } -// SetContentType ... +// SetContentType must be called before SetResponseStatus (w.WriteHeader) (?) func SetContentType(w http.ResponseWriter, ctype string) { w.Header().Set("Content-Type", ctype) } @@ -73,12 +71,11 @@ func WriteResponse(w http.ResponseWriter, content []byte) { w.Write(content) } -// SetDefaultHeaders set's default headers for an HTTP response. -func SetDefaultHeaders(w http.ResponseWriter) { +// SetAccessControlHeaders set's default headers for an HTTP response. +func SetAccessControlHeaders(w http.ResponseWriter) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "POST, GET, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization") - SetContentType(w, "application/json; charset=utf-8") } // GetLocale ... @@ -100,11 +97,13 @@ func Success(w http.ResponseWriter, payload interface{}, code int) { // OK ... func OK(w http.ResponseWriter, payload interface{}) { + SetContentType(w, "application/json") Success(w, payload, http.StatusOK) } // Created ... func Created(w http.ResponseWriter, payload interface{}) { + SetContentType(w, "application/json") Success(w, payload, http.StatusCreated) } @@ -122,30 +121,36 @@ func Error(w http.ResponseWriter, r *http.Request, code int, err string) { // BadRequest ... func BadRequest(w http.ResponseWriter, r *http.Request, err string) { + SetContentType(w, "application/json") Error(w, r, http.StatusBadRequest, err) } // Unauthorized ... func Unauthorized(w http.ResponseWriter, r *http.Request, err string) { + SetContentType(w, "application/json") Error(w, r, http.StatusUnauthorized, err) } // Forbidden ... func Forbidden(w http.ResponseWriter, r *http.Request, err string) { + SetContentType(w, "application/json") Error(w, r, http.StatusForbidden, err) } // NotFound ... func NotFound(w http.ResponseWriter, r *http.Request, err string) { + SetContentType(w, "application/json") Error(w, r, http.StatusNotFound, err) } // Conflict ... func Conflict(w http.ResponseWriter, r *http.Request, err string) { + SetContentType(w, "application/json") Error(w, r, http.StatusConflict, err) } // InternalServerError ... func InternalServerError(w http.ResponseWriter, r *http.Request, err string) { + SetContentType(w, "application/json") Error(w, r, http.StatusInternalServerError, err) } diff --git a/middleware/main.go b/middleware/main.go index 55e8654..8c1da43 100644 --- a/middleware/main.go +++ b/middleware/main.go @@ -5,17 +5,17 @@ import ( ) func Headers(h http.HandlerFunc) http.HandlerFunc { - return IgnoreOptionsRequests(ParseForm(h)) + return SetAccessControlHeaders(IgnoreOptionsRequests(ParseForm(h))) } -func AuthOnly(roles string, h http.HandlerFunc) http.HandlerFunc { - return IgnoreOptionsRequests(ParseForm(Auth(roles, h))) +func AuthUser(roles string, h http.HandlerFunc) http.HandlerFunc { + return SetAccessControlHeaders(IgnoreOptionsRequests(ParseForm(Auth(roles, h)))) } -func Full(roles string, h http.HandlerFunc) http.HandlerFunc { - return IgnoreOptionsRequests(ParseForm(LogTraffic(Auth(roles, h)))) +func AuthUserLogTraffic(roles string, h http.HandlerFunc) http.HandlerFunc { + return SetAccessControlHeaders(IgnoreOptionsRequests(ParseForm(LogHTTP(Auth(roles, h))))) } func LogTraffic(h http.HandlerFunc) http.HandlerFunc { - return IgnoreOptionsRequests(ParseForm(LogRequestAndResponse(h))) + return SetAccessControlHeaders(IgnoreOptionsRequests(ParseForm(LogHTTP(h)))) } diff --git a/middleware/middleware.go b/middleware/middleware.go index b106281..ec547a8 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -11,13 +11,21 @@ import ( var httpLogger *gologger.Logger +func SetAccessControlHeaders(h http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + web.SetAccessControlHeaders(w) + + h(w, req) + } +} + // IgnoreOptionsRequests ... func IgnoreOptionsRequests(h http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { - web.SetDefaultHeaders(w) if req.Method == http.MethodOptions { return } + h(w, req) } } @@ -30,6 +38,7 @@ func ParseForm(h http.HandlerFunc) http.HandlerFunc { web.BadRequest(w, req, err.Error()) return } + h(w, req) } } @@ -42,6 +51,7 @@ func ParseMultipartForm(h http.HandlerFunc) http.HandlerFunc { web.BadRequest(w, req, err.Error()) return } + h(w, req) } } @@ -51,26 +61,27 @@ func SetLogger(logger *gologger.Logger) { httpLogger = logger } -// LogRequestAndResponse ... -func LogRequestAndResponse(h http.HandlerFunc) http.HandlerFunc { +// LogHTTP ... +func LogHTTP(h http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { - if httpLogger != nil { - t1 := time.Now() + if httpLogger == nil { + h(w, req) + return + } - claims, _ := web.GetTokenClaims(req) - in := httpLogger.LogHTTPRequest(req, claims.Username) + t1 := time.Now() - rec := web.NewStatusRecorder(w) + claims, _ := web.GetTokenClaims(req) + in := httpLogger.LogHTTPRequest(req, claims.Username) - h(rec, req) + rec := web.NewStatusRecorder(w) - t2 := time.Now() - out := httpLogger.LogHTTPResponse(rec.Status(), t2.Sub(t1), rec.Size()) + h(rec, req) - httpLogger.CombineHTTPLogs(in, out) - } else { - h(w, req) - } + t2 := time.Now() + out := httpLogger.LogHTTPResponse(rec.Status(), t2.Sub(t1), rec.Size()) + + httpLogger.CombineHTTPLogs(in, out) } } @@ -81,6 +92,7 @@ func Auth(roles string, h http.HandlerFunc) http.HandlerFunc { web.Unauthorized(w, req, err.Error()) return } + h(w, req) } } -- 1.8.1.2