// auth_utility.go (6.14 KB)
// TODO: Improve roles
package webutility

import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"net/http"
	"strings"
	"time"

	"github.com/dgrijalva/jwt-go"
)

// Durations and settings shared by the token and hashing helpers below.
const (
	// OneDay is a 24-hour duration.
	OneDay = 24 * time.Hour
	// OneWeek is seven days; it is the lifetime given to issued and refreshed
	// API tokens.
	OneWeek = 7 * OneDay

	// saltSize is the number of random bytes generateSalt reads; the salt is
	// hex-encoded, so the resulting string is twice as long.
	saltSize = 32
	// appName is written into the Issuer claim of issued tokens.
	appName = "korisnicki-centar"
	// secret is the HMAC key used to sign and verify API tokens.
	// NOTE(review): a key compiled into the binary is shared by every
	// deployment and visible in source control; consider loading it from
	// configuration or the environment.
	secret = "korisnicki-centar-api"
)

// Role identifies an authorization role by its name and numeric ID.
// CreateAPIToken copies both values into the claims of the token it issues.
type Role struct {
	Name string `json:"name"` // role name; ends up in TokenClaims.Role
	ID   uint32 `json:"id"`   // numeric role identifier; ends up in TokenClaims.RoleID
}

// TokenClaims is the payload carried by the API's JWTs: who the token was
// issued to and which role they hold, plus the registered claims provided by
// jwt.StandardClaims (this package sets ExpiresAt and Issuer).
type TokenClaims struct {
	Username string `json:"username"` // user the token was issued to
	Role     string `json:"role"`     // role name; RbacCheck compares it against authRoles
	RoleID   uint32 `json:"roleID"`   // numeric ID of that role
	jwt.StandardClaims
}

// CredentialsStruct is a username/password pair, serialized to and from JSON
// under the keys "username" and "password".
type CredentialsStruct struct {
	Username string `json:"username"`
	Password string `json:"password"`
}

// generateSalt reads saltSize bytes from crypto/rand and returns them
// hex-encoded, so the returned string is 2*saltSize characters long.
// It returns an error if the random source cannot be read.
func generateSalt() (salt string, err error) {
	buf := make([]byte, saltSize)
	if _, err = rand.Read(buf); err != nil {
		return "", err
	}
	return hex.EncodeToString(buf), nil
}

// HashString hashes str together with a hex-encoded salt using SHA-256.
//
// If presalt is empty, a fresh salt is produced by generateSalt; otherwise
// presalt is used as-is and must be valid hex. The function returns the
// hex-encoded digest and the salt that was used, so the caller can store the
// salt next to the hash. It returns an error if salt generation fails or the
// salt is not valid hex.
//
// NOTE: the digest is taken over len(str)+len(decoded salt) zero bytes,
// followed by str, followed by the decoded salt bytes. The zero prefix comes
// from how the input buffer was originally assembled (a length-n make followed
// by appends). It is kept deliberately: changing the hashed bytes would
// invalidate every hash that has already been stored.
// NOTE(review): a single salted SHA-256 is a fast hash and a weak choice for
// password storage; consider bcrypt/scrypt/argon2 with a migration path.
func HashString(str, presalt string) (hash, salt string, err error) {
	salt = presalt
	if salt == "" {
		if salt, err = generateSalt(); err != nil {
			return "", "", err
		}
	}

	saltBytes, err := hex.DecodeString(salt)
	if err != nil {
		return "", "", err
	}

	// Stream the exact legacy input: zero prefix, then the message, then the salt.
	// hash.Hash.Write never returns an error.
	h := sha256.New()
	h.Write(make([]byte, len(str)+len(saltBytes)))
	h.Write([]byte(str))
	h.Write(saltBytes)

	return hex.EncodeToString(h.Sum(nil)), salt, nil
}

// CreateAPIToken returns an HS256-signed JWT whose claims carry the username,
// the role's name and ID, an expiration one week (OneWeek) from now, and
// appName as the issuer. The token is signed with the package-level secret.
// It returns an error if signing fails.
func CreateAPIToken(username string, role Role) (string, error) {
	claims := TokenClaims{
		Username: username,
		Role:     role.Name,
		RoleID:   role.ID,
		StandardClaims: jwt.StandardClaims{
			ExpiresAt: time.Now().Add(OneWeek).Unix(),
			Issuer:    appName,
		},
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	signed, err := token.SignedString([]byte(secret))
	if err != nil {
		return "", err
	}
	return signed, nil
}

// RefreshAPIToken re-issues a token with its expiration moved to one week
// (OneWeek) from now; all other claims are carried over unchanged.
//
// A leading "Bearer " prefix is stripped if present. The token must parse and
// verify with secretFunc; jwt-go's ParseWithClaims also runs the standard-claim
// checks (such as exp), so a token that has already expired cannot be
// refreshed. It returns the newly signed token, the parse/signing error, or
// "token is not valid" when the claims are unusable.
func RefreshAPIToken(tokenString string) (string, error) {
	raw := strings.TrimPrefix(tokenString, "Bearer ")

	parsed, err := jwt.ParseWithClaims(raw, &TokenClaims{}, secretFunc)
	if err != nil {
		return "", err
	}

	claims, isTokenClaims := parsed.Claims.(*TokenClaims)
	if !(isTokenClaims && parsed.Valid) {
		return "", errors.New("token is not valid")
	}

	claims.ExpiresAt = time.Now().Add(OneWeek).Unix()
	refreshed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret))
	if err != nil {
		return "", err
	}
	return refreshed, nil
}

// ParseAPIToken verifies a JWT supplied as an Authorization header value of
// the form "Bearer <token>" and returns its claims.
//
// On failure it returns a non-nil, zero-valued *TokenClaims together with the
// error. Failures are a missing "Bearer " prefix, a parse or verification error
// from jwt-go, or "token is not valid" when the claims are unusable.
func ParseAPIToken(tokenString string) (*TokenClaims, error) {
	const bearer = "Bearer "
	if !strings.HasPrefix(tokenString, bearer) {
		return &TokenClaims{}, errors.New("Authorization header is incomplete")
	}

	parsed, err := jwt.ParseWithClaims(tokenString[len(bearer):], &TokenClaims{}, secretFunc)
	if err != nil {
		return &TokenClaims{}, err
	}

	claims, isTokenClaims := parsed.Claims.(*TokenClaims)
	if isTokenClaims && parsed.Valid {
		return claims, nil
	}
	return &TokenClaims{}, errors.New("token is not valid")
}

// GetTokenClaims returns the verified JWT claims from the Authorization header
// of r. It behaves exactly like ParseAPIToken applied to that header value:
// - the header must have the form "Bearer <token>";
// - on any failure a non-nil, zero-valued *TokenClaims is returned together
//   with the error.
func GetTokenClaims(r *http.Request) (claims *TokenClaims, err error) {
	// Delegate rather than duplicate the prefix/parse/validate sequence.
	return ParseAPIToken(r.Header.Get("Authorization"))
}

// secretFunc is the jwt.Keyfunc used when parsing API tokens. It returns the
// package secret as the verification key, but only for HMAC-signed tokens.
// Every token this package issues is HS256, so any other signing method is
// rejected here instead of relying on the verifier to fail on a mismatched
// key type.
func secretFunc(token *jwt.Token) (interface{}, error) {
	if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
		return nil, errors.New("unexpected signing method")
	}
	return []byte(secret), nil
}

// RbacCheck reports whether the caller behind req may access a resource
// restricted to authRoles.
//
// A nil authRoles means the resource is unrestricted, and true is returned
// without inspecting the request. Otherwise the "Authorization" header must
// hold a valid "Bearer <token>" (checked by ParseAPIToken), or false is
// returned. With a valid token, access is granted when authRoles contains the
// token's role or the wildcard "*".
//
// Note that "*" still requires a valid token. An empty but non-nil authRoles
// denies every request.
func RbacCheck(req *http.Request, authRoles []string) bool {
	if authRoles == nil {
		return true
	}

	claims, err := ParseAPIToken(req.Header.Get("Authorization"))
	if err != nil {
		return false
	}

	for _, allowed := range authRoles {
		switch allowed {
		case "*", claims.Role:
			return true
		}
	}
	return false
}

// RbacHandler wraps handlerFunc with common response headers and role-based
// access control.
//
// Every response gets permissive CORS headers and a JSON Content-Type.
// Preflight (OPTIONS) requests are answered with those headers only and never
// reach handlerFunc. For any other method, handling proceeds in order:
// 1. The request form is parsed; on failure BadRequestResponse is sent.
// 2. RbacCheck is applied with authRoles; on failure UnauthorizedResponse is
//    sent.
// 3. Only then is handlerFunc invoked.
func RbacHandler(handlerFunc http.HandlerFunc, authRoles []string) http.HandlerFunc {
	return func(w http.ResponseWriter, req *http.Request) {
		h := w.Header()
		h.Set("Access-Control-Allow-Origin", "*")
		h.Set("Access-Control-Allow-Methods", "POST, GET, PUT, DELETE, OPTIONS")
		// Keep this value on one line. A multi-line raw string would embed a
		// newline and indentation into the header value, and a newline is not
		// a valid character in an HTTP header field value.
		h.Set("Access-Control-Allow-Headers",
			"Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
		h.Set("Content-Type", "application/json; charset=utf-8")

		// TODO: Check for content type

		if req.Method == http.MethodOptions {
			return
		}

		if err := req.ParseForm(); err != nil {
			BadRequestResponse(w, req)
			return
		}

		if !RbacCheck(req, authRoles) {
			UnauthorizedResponse(w, req)
			return
		}

		handlerFunc(w, req)
	}
}