要するにJWT認証をしたい。JWT認証をしたいんですけど(Scalaの)記事がないんじゃぁぁぁ。というわけで書いた記事になります。ここでいうJWT認証は実際のシステムでいうとサーバサイドの話であって今回はクライアントサイドは関係ないです。JWTはAWS Cognitoから発行されたものという前提で書いてますが、他のでも同様のやり方でできるハズです。

ライブラリを探す


ちょろっとできないかな~とライブラリを探した結果、どうもJVM界隈ではnimbus-jose-jwtというのが有名っぽかったので、こちらを使うことにしました。この他にも多くのライブラリがありますが、探すにはJWT.IOを使うのが良いと思います。

Scalaではjwt-scalaというライブラリが有名そうで、こちらはJSONライブラリのcirce向けのものも存在するということもあって使ってみようかと思ったのですが、ドキュメントを読んだ感じだとJWTのヘッダーのkidを取得して公開鍵を生成してごにょごにょしなければならないようで、手間そうだったのでnimbus-jose-jwtを使うこととしました。

環境


下記環境で試してます。

  • Scala 2.12.8
  • nimbus-jose-jwt 7.1

やること


まず、流れを書いておかないと記事としてアレなので雑にやることを書きます。

JSON ウェブトークンの検証

やることは上記の通りです。上記をザックリ訳すと下記になります。

  • JWK(JSON Web Key)をダウンロードする
    • Cognitoの場合は https://cognito-idp.{region}.amazonaws.com/{userPoolId}/.well-known/jwks.json にある
  • JWKから公開鍵を生成する
  • JWTをデコードする
  • JWTの署名、有効期限、発行者、token_useクレームの検証する

あっ、これなんかめんどくさそうですね。めんどくさそうですが、nimbus-jose-jwtだったらちょろっとでできました。

コードを書いてみる


Validating bearer JWT access tokens

だいたいもう上記のドキュメントの通りなんですが、まず、ざっくりと書いてみると下記のような感じになりました。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import java.net.URL

import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.jwk.source.{JWKSource, RemoteJWKSet}
import com.nimbusds.jose.proc.{JWSVerificationKeySelector, SecurityContext}
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.proc.{BadJWTException, ConfigurableJWTProcessor, DefaultJWTProcessor}

object AWSCognitoJWT extends App {

val accessToken = "<実際にクライアントから渡ってくるJWT>"

// AWS CognitoのJWK(リージョンとプールIDは自分のものを指定する)
val awsCognitoJwk = "https://cognito-idp.{region}.amazonaws.com/{userPoolId}/.well-known/jwks.json"

// JWKをAWS Cognitoからダウンロードしてくる(これはアプリケーションの起動時のみという設計にする)
val jwkSet: JWKSource[SecurityContext] = new RemoteJWKSet(new URL(awsCognitoJwk))

// 署名とか有効期限をチェックするプロセッサーのインスタンスを生成
val jwtProcessor: ConfigurableJWTProcessor[SecurityContext] = new DefaultJWTProcessor[SecurityContext]

// トークンのアルゴリズムを指定してJWSVerificationKeySelectorインスタンスを生成してプロセッサーにセット
val jWSVerificationKeySelector: JWSVerificationKeySelector[SecurityContext] = new JWSVerificationKeySelector(JWSAlgorithm.RS256, jwkSet)
jwtProcessor.setJWSKeySelector(jWSVerificationKeySelector)

// プロセッサーでチェックする(ここで認証ダメなら例外発生するので下記のコードは厳密にはダメ)
val claimsSet = jwtProcessor.process(accessToken, null)

println(claimsSet.toJSONObject())
}

これの結果で下記のようにJWTをデコードしたものが取得できます。ちなみに、これはJWT.IOのデバッガーでデコードした場合のPayloadの値と同じです。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 値はダミーです
{
"sub": "96487788-4567-4a78-b3c9-cf5609pb0643",
"event_id": "jiok3456-73s4-12e9-8e68-llo0e25t6yd3",
"token_use": "access",
"scope": "aws.cognito.signin.user",
"auth_time": 1557575911,
"iss": "https://cognito-idp.ap-northeast-1.amazonaws.com/ap-northeast-1_ABCDEFGHIJK",
"exp": 1557674049,
"iat": 1557670449,
"jti": "tyu8505t-3hka-4a11-3c5k-4loijaclmhy9",
"client_id": "lijtu68mjbxh78459kiiiuy7ba",
"username": "test-user2"
}

また、例えば有効期限が切れている場合などはjwtProcessor.process実行時に下記のような例外が出ます。つまり、プロセッサーを通したら有効かどうかもチェックできているということですね。

1
2
3
Exception in thread "main" com.nimbusds.jwt.proc.BadJWTException: Expired JWT
at com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier.<clinit>(DefaultJWTClaimsVerifier.java:62)
at com.nimbusds.jwt.proc.DefaultJWTProcessor.<init>(DefaultJWTProcessor.java:139)

コードを少しなおす


ついでなのでもうちょっと修正したコードを載せときます。最終的にはだいたい下記のような感じになるんじゃないかなぁぁぁ…と思います。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import java.net.URL

import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.jwk.source.{JWKSource, RemoteJWKSet}
import com.nimbusds.jose.proc.{JWSVerificationKeySelector, SecurityContext}
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.proc.{BadJWTException, ConfigurableJWTProcessor, DefaultJWTProcessor}

import scala.util.{Failure, Success, Try}

object AwsCognitoJwt {

private val awsCognitoJwk = "https://cognito-idp.{region}.amazonaws.com/{userPoolId}/.well-known/jwks.json"
private val jwkSet: JWKSource[SecurityContext] = new RemoteJWKSet(new URL(awsCognitoJwk))

private val jwtProcessor: ConfigurableJWTProcessor[SecurityContext] = new DefaultJWTProcessor[SecurityContext]
private val jWSVerificationKeySelector: JWSVerificationKeySelector[SecurityContext] = new JWSVerificationKeySelector(JWSAlgorithm.RS256, jwkSet)

jwtProcessor.setJWSKeySelector(jWSVerificationKeySelector)

def validate(jwt: String): Either[String, JWTClaimsSet] = {
Try(jwtProcessor.process(jwt, null)) match {
case Success(jwtClaimsSet: JWTClaimsSet) => Right(jwtClaimsSet)
case Failure(badJwtException: BadJWTException) => Left("無効なJWTです")
case Failure(exception: Exception) => Left("例外が発生しました")
}
}

}

おわり。