simple-sso/src/ssso_crypto.lua

132 lines
4.1 KiB
Lua
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

local logic = require("bit")
local json = require("cjson.safe")
local aes = require("resty.openssl.cipher")
local random = require("resty.random")
local s256 = require("resty.sha256")
local b64 = require("ssso_base64")
local config = require("ssso_config")
local log = require("ssso_log")
local nginx = require("ssso_nginx")
local sites = require("ssso_sites")
local KEY_SIZE = 32 -- 256 bits for AES-256-GCM's key and for the HMAC-SHA-256 key
local IV_SIZE = 12 -- 96 bits for AES-256-GCM's IV
local TAG_SIZE = 16 -- 128 bits for AES-256-GCM's tag
local HMAC_BLOCK_SIZE = 64 -- SHA-256 block size in bytes (512 bits)
-- Additional authenticated data bound to every AES-GCM ciphertext of this
-- process. AAD is authenticated but not secret, so the default (non-strong)
-- generator is sufficient here.
local gcm_aad = random.bytes(8)
-- Per-process secret keys: prefer the strong OpenSSL RNG and fall back to the
-- pseudo-random one only if the strong source fails. Encryption and signing use
-- distinct keys so that no key is shared between two different primitives.
local symkey = random.bytes(KEY_SIZE, true) or random.bytes(KEY_SIZE, false)
local mackey = random.bytes(KEY_SIZE, true) or random.bytes(KEY_SIZE, false)
-- https://www.rfc-editor.org/rfc/rfc2104 (HMAC)
-- HMAC right-pads the key with zero bytes up to the hash block size and only
-- then XORs it with ipad (0x36) and opad (0x5c). Without that padding the
-- result is not HMAC-SHA-256 and does not match the "HS256" algorithm
-- advertised in the JWS header. (Keys longer than the block size would have to
-- be hashed first; KEY_SIZE is well below HMAC_BLOCK_SIZE.)
local i_key_pad, o_key_pad
do
  local padded = mackey .. string.rep("\0", HMAC_BLOCK_SIZE - #mackey)
  local ipad, opad = {}, {}
  for i = 1, HMAC_BLOCK_SIZE do
    local b = padded:byte(i)
    ipad[i] = string.char(logic.bxor(0x36, b))
    opad[i] = string.char(logic.bxor(0x5c, b))
  end
  i_key_pad = table.concat(ipad)
  o_key_pad = table.concat(opad)
end
-- https://www.rfc-editor.org/rfc/rfc7515.html#appendix-A.1
-- Base64url-encoded protected header; to_jwt accepts only this exact header.
local jose_256_b64 = b64.encode_base64url('{"alg":"HS256"}')
--- Encrypt a byte string with AES-256-GCM under the module key.
-- A fresh IV is drawn for every call and `gcm_aad` is bound to the ciphertext
-- as additional authenticated data.
-- @tparam string bytes plaintext
-- @treturn string|nil iv .. ciphertext .. tag, or nil if any cipher step fails
local function encrypt(bytes)
  local iv = random.bytes(IV_SIZE, true) or random.bytes(IV_SIZE, false)
  local gcm = aes.new("aes-256-gcm")
  if not gcm then
    return nil
  end
  local crypted = gcm:encrypt(symkey, iv, bytes, false, gcm_aad)
  if not crypted then
    return nil
  end
  -- Ask for exactly TAG_SIZE bytes so the layout always matches what decrypt()
  -- slices off the end; a missing tag must not reach the concatenation below.
  local tag = gcm:get_aead_tag(TAG_SIZE)
  if not tag then
    return nil
  end
  return iv .. crypted .. tag
end
--- Decrypt a blob produced by encrypt().
-- @tparam string bytes iv .. ciphertext .. tag
-- @treturn string|nil plaintext, or nil if the input is malformed, the cipher
--   cannot be created, or decryption/authentication fails
local function decrypt(bytes)
  -- Anything that is not a string of at least IV + tag bytes cannot have come
  -- from encrypt(); slicing it would only produce garbage arguments.
  if type(bytes) ~= "string" or #bytes < IV_SIZE + TAG_SIZE then
    return nil
  end
  local iv = bytes:sub(1, IV_SIZE)
  local contents = bytes:sub(IV_SIZE + 1, -TAG_SIZE - 1)
  local tag = bytes:sub(-TAG_SIZE)
  local gcm = aes.new("aes-256-gcm")
  if not gcm then
    return nil
  end
  -- The tag and AAD are passed in, so a failed check is expected to yield nil;
  -- the parentheses keep this function's result to a single value.
  return (gcm:decrypt(symkey, iv, contents, false, gcm_aad, tag))
end
--- Keyed two-pass SHA-256 over `message` using the module-level key pads:
-- SHA-256(o_key_pad .. SHA-256(i_key_pad .. message)).
-- The pads and the message are fed to the hash incrementally, which yields the
-- same digest as hashing their concatenation.
-- @tparam string message
-- @treturn string raw digest bytes from resty.sha256
local function hmac(message)
  local inner_hash = s256:new()
  inner_hash:update(i_key_pad)
  inner_hash:update(message)
  local inner_digest = inner_hash:final()

  local outer_hash = s256:new()
  outer_hash:update(o_key_pad)
  outer_hash:update(inner_digest)
  return outer_hash:final()
end
--- Encode a claims table as a compact JWS signed with HS256.
-- The MAC covers header64 .. "." .. payload64, the JWS Signing Input defined
-- in RFC 7515 section 5.1.
-- @tparam table jwt claims to place in the payload
-- @treturn string "<header64>.<payload64>.<signature64>"
local function to_jws(jwt)
  local jwt64 = b64.encode_base64url(json.encode(jwt))
  local signing_input = jose_256_b64 .. "." .. jwt64
  return signing_input .. "." .. b64.encode_base64url(hmac(signing_input))
end

--- Verify a compact JWS produced by to_jws() and decode its claims.
-- Only the exact protected header this module emits is accepted, so the
-- algorithm cannot be chosen by whoever presents the token.
-- @param jws token to check (non-strings are rejected)
-- @return decoded claims, or nil if the token is malformed or the signature
--   does not match
local function to_jwt(jws)
  if type(jws) ~= "string" then
    return nil
  end
  local jwslen = #jws
  -- Plain (non-pattern) search for the two separators.
  local dot1 = jws:find(".", 1, true)
  if not dot1 or dot1 == jwslen then
    return nil
  end
  local dot2 = jws:find(".", dot1 + 1, true)
  if not dot2 or dot2 == jwslen then
    return nil
  end
  if jws:sub(1, dot1 - 1) ~= jose_256_b64 then
    return nil
  end
  -- Everything before the second dot is exactly the JWS Signing Input.
  local signing_input = jws:sub(1, dot2 - 1)
  local sig = jws:sub(dot2 + 1)
  if sig ~= b64.encode_base64url(hmac(signing_input)) then
    return nil
  end
  return json.decode(b64.decode_base64url(jws:sub(dot1 + 1, dot2 - 1)))
end
-- https://www.rfc-editor.org/rfc/rfc7519
-- https://openid.net/specs/openid-connect-core-1_0.html
--- Issue a signed session token for a profile.
-- The serialized profile is encrypted and carried base64url-encoded in the
-- private `x_ssso` claim next to the registered claims.
-- @param profile object providing :user() and :serialize()
-- @return compact JWS string, or nil
-- @return the token's `exp` value (now + configured session seconds), or nil
local function get_jws_and_tslimit(profile)
  local user = profile:user()
  local ser_profile = profile:serialize()
  if type(ser_profile) ~= "string" then
    return nil, nil
  end
  -- Control characters 26-31 are shown as "[n]" so the debug line stays printable.
  log.debug("Creating JWS with profile: " .. ser_profile:gsub("([\031\030\029\028\027\026])", function(s) return "[" .. s:byte() .. "]" end))
  -- No token without a user; check before spending an encryption on it.
  if not user then
    return nil, nil
  end
  local crypted_profile = encrypt(ser_profile)
  if not crypted_profile then
    return nil, nil
  end
  local iat = nginx.get_seconds_since_epoch()
  local exp = iat + config.get_session_seconds()
  local jwt = {
    iss = "https://" .. config.get_sso_host(),
    sub = user,
    -- NOTE(review): OIDC defines `aud` as the relying party's client_id(s);
    -- here it is the user. Confirm this is intentional.
    aud = user,
    exp = exp,
    iat = iat,
    x_ssso = b64.encode_base64url(crypted_profile),
  }
  return to_jws(jwt), exp
end
--- Validate a session token and issue a refreshed one.
-- @param jws compact JWS previously returned by get_jws_and_tslimit
-- @return deserialized profile, refreshed JWS and its new `exp`;
--   or nil, nil, nil if the token is invalid, expired or undecryptable
local function get_profile_and_new_jws(jws)
  local jwt = to_jwt(jws)
  local now = nginx.get_seconds_since_epoch()
  -- RFC 7519 section 4.1.4: the token MUST NOT be accepted on or after `exp`.
  if type(jwt) ~= "table" or type(jwt.x_ssso) ~= "string"
      or type(jwt.exp) ~= "number" or jwt.exp <= now then
    return nil, nil, nil
  end
  local crypted_profile = b64.decode_base64url(jwt.x_ssso)
  if not crypted_profile then
    return nil, nil, nil
  end
  local ser_profile = decrypt(crypted_profile)
  if not ser_profile then
    return nil, nil, nil
  end
  -- Escape the same control characters (26-31) as when the token is created.
  log.debug("Read profile from JWS: " .. ser_profile:gsub("([\031\030\029\028\027\026])", function(s) return "[" .. s:byte() .. "]" end))
  local profile = sites.class__profile:deserialize(ser_profile)
  if not profile then
    return nil, nil, nil
  end
  -- Sliding session: re-issue the same claims with a fresh iat/exp.
  jwt.iat = now
  jwt.exp = now + config.get_session_seconds()
  return profile, to_jws(jwt), jwt.exp
end
--- Public API of this module: token issuance and token validation/refresh.
local M = {}
M.get_jws_and_tslimit = get_jws_and_tslimit
M.get_profile_and_new_jws = get_profile_and_new_jws
return M