auth.go 6.32 KB
package webutility

import (
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
	"errors"
	"net/http"
	"strings"
	"time"

	"github.com/dgrijalva/jwt-go"
)

// Token lifetimes and hashing parameters.
const (
	// OneDay is a 24-hour duration.
	OneDay = 24 * time.Hour
	// OneWeek is seven days; CreateAuthToken uses it as the token lifetime.
	OneWeek = 7 * OneDay
	// saltSize is the number of random bytes randomSalt draws for a salt.
	saltSize = 32
)

// JWT issuer name and HMAC signing secret, normally replaced via InitJWT.
//
// SECURITY: the default secret is a fixed, publicly visible literal. Tokens
// signed with it can be forged by anyone who has read this source, so
// production code must configure a real secret before issuing tokens.
var (
	appName = "webutility"
	secret  = "webutility"
)

// Role is a user role identified by a name and a numeric ID. CreateAuthToken
// copies Name into the "role" claim and ID into the "role_id" claim.
//
// Keep the field order (Name, ID) stable: Role values may be built with
// positional composite literals.
type Role struct {
	Name string `json:"name"`
	ID   int64  `json:"id"`
}

// TokenClaims are the JWT claims issued by CreateAuthToken and parsed back by
// GetTokenClaims and RefreshAuthToken.
//
// The embedded jwt.StandardClaims supplies the registered claims; this package
// sets IssuedAt, ExpiresAt and Issuer on it. The remaining fields are custom:
//   - Token is the signed JWT string. CreateAuthToken fills it in only after
//     signing, so it is empty inside the signed payload itself.
//   - TokenType is "Bearer" for tokens created by CreateAuthToken.
//   - Username, Role and RoleID identify the user and the user's role.
//   - ExpiresIn is the token lifetime in seconds (exp minus iat).
type TokenClaims struct {
	// extending a struct: embedding promotes the registered-claim fields
	// (IssuedAt, ExpiresAt, Issuer, ...) and the Valid method
	jwt.StandardClaims

	// custom claims
	Token     string `json:"access_token"`
	TokenType string `json:"token_type"`
	Username  string `json:"username"`
	Role      string `json:"role"`
	RoleID    int64  `json:"role_id"`
	ExpiresIn int64  `json:"expires_in"`
}

func InitJWT(appName, secret string) {
	appName = appName
	secret = secret
}

// ValidateCredentials reports whether hashing pass with salt (via CreateHash)
// reproduces resultHash.
//
// salt must be the hex-encoded salt stored alongside resultHash. If salt is
// empty, CreateHash generates a fresh random salt, so the result will not
// match. Any error from CreateHash (for example, an invalid hex salt) is
// returned with false.
func ValidateCredentials(pass, salt, resultHash string) (bool, error) {
	hash, _, err := CreateHash(pass, salt)
	if err != nil {
		return false, err
	}
	// Constant-time comparison: response timing must not reveal how many
	// leading characters of the stored hash matched.
	return subtle.ConstantTimeCompare([]byte(hash), []byte(resultHash)) == 1, nil
}

// CreateHash hashes str with a salt using SHA-256 and returns the hex-encoded
// digest together with the hex-encoded salt that was used.
//
// If presalt is non-empty, it must be a hex string (such as the salt returned
// by an earlier call) and is used as the salt. If presalt is empty, a new salt
// is generated with randomSalt. If the salt cannot be generated or presalt is
// not valid hex, CreateHash returns empty hash and salt strings and the error.
//
// SECURITY NOTE(review): a single salted SHA-256 pass is cheap to brute-force.
// Password storage normally uses a deliberately slow KDF (bcrypt, scrypt,
// Argon2). Switching would invalidate every stored hash, so it needs a
// migration plan rather than an in-place change.
func CreateHash(str, presalt string) (hash, salt string, err error) {
	salt = presalt
	if salt == "" {
		if salt, err = randomSalt(); err != nil {
			return "", "", err
		}
	}

	rawsalt, err := hex.DecodeString(salt)
	if err != nil {
		return "", "", err
	}

	// Hash-format compatibility: the digest covers n zero bytes followed by
	// str and then the raw salt, where n = len(str)+len(rawsalt). The zero
	// prefix originally came from an accidental make-with-length followed by
	// append. It is kept on purpose, because removing it would change every
	// hash and break validation of credentials already stored.
	n := len(str) + len(rawsalt)
	rawdata := make([]byte, n, 2*n)
	rawdata = append(rawdata, str...)
	rawdata = append(rawdata, rawsalt...)

	sum := sha256.Sum256(rawdata)
	return hex.EncodeToString(sum[:]), salt, nil
}

// CreateAuthToken builds and signs an HS256 JWT for username and role.
//
// The token carries the username, the role name and ID, the issue time (iat),
// an expiration one week later (exp), and the configured issuer (appName). It
// is signed with the package secret. The returned claims include the signed
// token string in Token and the lifetime in seconds in ExpiresIn. If signing
// fails, it returns zero-valued claims and the error.
func CreateAuthToken(username string, role Role) (TokenClaims, error) {
	// Take a single instant so iat, exp and ExpiresIn agree exactly. Two
	// separate time.Now() calls can straddle a second boundary.
	now := time.Now()
	issuedAt := now.Unix()
	expiresAt := now.Add(OneWeek).Unix()

	claims := TokenClaims{
		TokenType: "Bearer",
		Username:  username,
		Role:      role.Name,
		RoleID:    role.ID,
		ExpiresIn: expiresAt - issuedAt,
	}
	// registered claims live on the embedded jwt.StandardClaims
	claims.IssuedAt = issuedAt
	claims.ExpiresAt = expiresAt
	claims.Issuer = appName

	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret))
	if err != nil {
		return TokenClaims{}, err
	}
	claims.Token = signed
	return claims, nil
}

// RefreshAuthToken issues a new token, with a fresh one-week lifetime from
// CreateAuthToken, for the username and role carried in tok.
//
// tok may be expired, but it must otherwise be fully valid: correctly signed
// with the package secret and passing every other claims check. It returns the
// new claims, or an error if tok cannot be refreshed.
func RefreshAuthToken(tok string) (TokenClaims, error) {
	token, err := jwt.ParseWithClaims(tok, &TokenClaims{}, secretFunc)
	if err != nil {
		// jwt-go reports each failed check as a bit in ValidationError.Errors.
		// An expired token with a bad signature carries both the Expired and
		// SignatureInvalid bits, so testing only the Expired bit would let a
		// forged token be re-signed. Accept expiry only when it is the sole
		// failure.
		validation, ok := err.(*jwt.ValidationError)
		if !ok || validation.Errors != jwt.ValidationErrorExpired {
			return TokenClaims{}, err
		}
	}
	if token == nil {
		return TokenClaims{}, errors.New("token is not valid")
	}

	claims, ok := token.Claims.(*TokenClaims)
	if !ok {
		return TokenClaims{}, errors.New("token is not valid")
	}

	// extend token expiration date by issuing a brand-new token
	return CreateAuthToken(claims.Username, Role{Name: claims.Role, ID: claims.RoleID})
}

// RbacCheck returns true if user that made HTTP request is authorized to
// access the resource it is targeting.
// It exctracts user's role from the JWT token located in Authorization header of
// http.Request and then compares it with the list of supplied roles and returns
// true if there's a match, if "*" is provided or if the authRoles is nil.
// Otherwise it returns false.
func RbacCheck(req *http.Request, authRoles []string) bool {
	if authRoles == nil {
		return true
	}

	// validate token and check expiration date
	claims, err := GetTokenClaims(req)
	if err != nil {
		return false
	}
	// check if token has expired
	if claims.ExpiresAt < (time.Now()).Unix() {
		return false
	}

	// check if role extracted from token matches
	// any of the provided (allowed) ones
	for _, r := range authRoles {
		if claims.Role == r || r == "*" {
			return true
		}
	}

	return false
}

// AuthCheck authorizes req against authRoles and returns the token claims it
// examined together with the decision.
//
// A nil authRoles means the resource is unrestricted: AuthCheck returns
// (nil, true) without inspecting the request. Otherwise the Bearer token in the
// Authorization header is parsed with GetTokenClaims. The request is allowed
// when the token is unexpired and either its role appears in authRoles or
// authRoles contains "*". A non-nil but empty authRoles therefore denies every
// request.
//
// On denial, the returned claims are whatever GetTokenClaims produced, which
// is a zero-valued struct when parsing failed.
func AuthCheck(req *http.Request, authRoles []string) (*TokenClaims, bool) {
	if authRoles == nil {
		return nil, true
	}

	claims, err := GetTokenClaims(req)
	if err != nil || claims.ExpiresAt < time.Now().Unix() {
		return claims, false
	}

	for _, allowed := range authRoles {
		switch allowed {
		case "*", claims.Role:
			return claims, true
		}
	}
	return claims, false
}

// GetTokenClaims parses and validates the JWT in the request's Authorization
// header, which must have the form "Bearer <token>". The scheme name is
// matched case-insensitively, because HTTP authentication schemes are
// case-insensitive.
//
// On success it returns the parsed claims. On any failure (missing or
// malformed header, unparsable token, bad signature, failed claims checks) it
// returns a non-nil, zero-valued *TokenClaims together with the error.
func GetTokenClaims(req *http.Request) (*TokenClaims, error) {
	const scheme = "Bearer "

	authHead := req.Header.Get("Authorization")
	if len(authHead) < len(scheme) || !strings.EqualFold(authHead[:len(scheme)], scheme) {
		return &TokenClaims{}, errors.New("authorization header is incomplete")
	}
	tokstr := authHead[len(scheme):]

	token, err := jwt.ParseWithClaims(tokstr, &TokenClaims{}, secretFunc)
	if err != nil {
		return &TokenClaims{}, err
	}

	claims, ok := token.Claims.(*TokenClaims)
	if !ok || !token.Valid {
		return &TokenClaims{}, errors.New("token is not valid")
	}
	return claims, nil
}

// randomSalt draws saltSize bytes from crypto/rand and returns them
// hex-encoded, so the string is 2*saltSize characters long. It returns the
// error from the random source if reading fails.
func randomSalt() (s string, err error) {
	buf := make([]byte, saltSize)
	if _, readErr := rand.Read(buf); readErr != nil {
		return "", readErr
	}
	return hex.EncodeToString(buf), nil
}

// secretFunc is the jwt.Keyfunc used when parsing tokens. It returns the
// package secret as the HMAC key.
//
// It refuses tokens whose header names a non-HMAC algorithm, so a token cannot
// choose how it is verified ("alg" confusion). Tokens created here are always
// HS256.
func secretFunc(token *jwt.Token) (interface{}, error) {
	if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
		return nil, errors.New("unexpected signing method")
	}
	return []byte(secret), nil
}