結構前にScalaでタイトルのことをやったのですが、それのPython版。

Cognitoではないけれども、自分がやりたいことほぼ同じことをやっている記事が既にあった。基本的にこちらを参考にさせてもらっている。

[Python] PyJWT で Google OAuth 2.0 API の ID Token を検証

環境


以下の環境でやっている。

  • Python 3.6.8
  • pyjwt 1.7.1
  • cryptography 3.1.1
  • flask 1.1.2

Flaskは今回の内容とは直接関係ないけど、API叩いたときにAuthorizationヘッダーのトークンが正しいかどうか。という実装をためしてみたので、そのAPIサーバとして使っている。

インストール


各々のパッケージをインストールする。前述のとおり、APIサーバ用途でflaskも使うので下記のコマンドはflaskを含んでいるけど、JWTのバリデーションだけなら不要。また、パッケージのインストールはpipenvを使っている。

1
$ pipenv install pyjwt cryptography flask

コード


検証に必要な内容は下記のAWSの公式ドキュメントに書いてある。

AWS Cognito: JSON ウェブトークンの検証

これに従うと…

  • JWK(JSON Web Key)をダウンロードする
    • Cognitoの場合は https://cognito-idp.{region}.amazonaws.com/{userPoolId}/.well-known/jwks.json
  • JWKから公開鍵を生成する
  • JWTをデコードする
  • JWTの署名、有効期限、発行者、token_useクレームの検証する

を行わないといけない。有効期限とかはライブラリでチェックしてくれる。

これに従って書くとザっと下記のようになる。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import json
import jwt
import requests

from jwt.algorithms import RSAAlgorithm

class Cognito:

cognito_iss = https://cognito-idp.<region>.amazonaws.com/<userPoolId>
cognito_app_client_id = <your-client-id>
cognito_jwk_url = cognito_iss + '/.well-known/jwks.json'
jwk_set = requests.get(cognito_jwk_url).json()

@classmethod
def validate_jwt(cls, token):

header = jwt.get_unverified_header(token)
jwk = next(filter(lambda x: x['kid'] == header['kid'], cls.jwk_set['keys']))
public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))

claims = jwt.decode(
token,
public_key,
issuer=cls.cognito_iss,
audience=cls.cognito_app_client_id,
algorithms=jwk['alg'],
options = dict(
verify_aud=False
)
)

if claims['client_id'] != cls.cognito_app_client_id:
raise Exception

if claims['token_use'] != 'access':
raise Exception

if claims['iss'] != cls.cognito_iss:
raise Exception

ところで、前述の上記ドキュメント内に audのチェックが必要な旨が書いてある。しかし、Cognitoから発行されたアクセストークンにはaudが含まれておらず、特にオプションも指定せずにそのままjwt.decodeすると例外が発生する。

これは「自分のCognitoの設定がおかしいのか??」とおもったものの、audはどうやらIDトークンには含まれるがアクセストークンには含まれないっぽい。

Why doesn’t Amazon Cognito return an audience field in its access tokens?

そのため、上記のコードではoptionsverify_audFlaseにしている。オプションはドキュメント参照。

PyJWT: API Reference

Flaskに組み込む


さて、せっかくなのでFlaskのAPIに組み込んでみる。

まず、独自の例外を作る。

1
2
3
# app.utils.exceptions
class UnauthorizedException(Exception):
pass

これを先ほどのクラスで使用するようにする。さらに、リクエストヘッダのAuthorizationからアクセストークンを取得してJWTのバリデーション関数に渡すデコレータを作る。デコレータ作らなくても全部のAPIに愚直に処理を書くという手もある。実際、少し調べたところハンドリングがめんどくさそうなのでAPIの数が少なければそれでいいかもしれないという気もする…

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import json
import jwt
import requests

from flask import Blueprint, Response, request
from jwt.algorithms import RSAAlgorithm
from functools import wraps
from app.utils.exceptions import UnauthorizedException

class Cognito:

cognito_iss = https://cognito-idp.<region>.amazonaws.com/<userPoolId>
cognito_app_client_id = <your-client-id>
cognito_jwk_url = cognito_iss + '/.well-known/jwks.json'
jwk_set = requests.get(cognito_jwk_url).json()

@classmethod
def validate_jwt(cls, token):

header = jwt.get_unverified_header(token)
jwk = next(filter(lambda x: x['kid'] == header['kid'], cls.jwk_set['keys']))
public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))

claims = jwt.decode(
token,
public_key,
issuer=cls.cognito_iss,
audience=cls.cognito_app_client_id,
algorithms=jwk['alg'],
options = dict(
verify_aud=False
)
)

if claims['client_id'] != cls.cognito_app_client_id:
raise UnauthorizedException

if claims['token_use'] != 'access':
raise UnauthorizedException

if claims['iss'] != cls.cognito_iss:
raise UnauthorizedException

# デコレータ
def jwt_validator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
token = request.headers['Authorization'].split()[1]
Cognito.validate_jwt(token)
return f(*args, **kwargs)
return decorated_function

作成したデコレータjwt_validatorをルーティングに足す。

1
2
3
4
5
@module_users.route("/users/<user_name>", methods=['GET'], strict_slashes=False)
@jwt_validator
def get_user(user_name):

...なんかいろんな処理

これでJWTのバリデーションに失敗した場合、良い感じに例外が発生するようになる。後は、例外に応じて良い感じにハンドリングしてレスポンス返すような実装があれば望ましいけど、そこまでやってないのでこれで終わり。