/*
 * Copyright (C) 2023 Curity AB. All rights reserved.
 *
 * The contents of this file are the property of Curity AB.
 * You may not copy or use this file, in either source code
 * or executable form, except in compliance with terms
 * set by Curity AB.
 *
 * For further information, please contact Curity AB.
 */

package io.curity.vp.services

import io.curity.ssi.crypto.CryptoException
import io.curity.ssi.crypto.VerificationKeySet
import io.curity.ssi.jose.JsonJwt
import io.curity.ssi.jose.JsonJwtClaimsSet
import io.curity.ssi.jose.JwsUtil
import io.curity.ssi.validation.ValidationScope
import io.curity.ssi.validation.suspendableValidationScope
import io.curity.vp.ClientIdScheme
import io.curity.vp.RequestObjectJwt
import io.curity.vp.RequestObjectJwt.Keys.NONCE
import io.curity.vp.RequestObjectJwt.Keys.RESPONSE_MODE
import io.curity.vp.RequestObjectJwt.Keys.RESPONSE_TYPE
import io.curity.vp.RequestObjectJwt.Keys.RESPONSE_URI
import io.curity.vp.RequestObjectJwt.Keys.STATE
import io.curity.vp.serialization.JsonPresentationDefinition
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.request.get
import kotlinx.datetime.Clock
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.jsonPrimitive

/**
 * Validates a request object JWT and extracts the request data from it.
 *
 * The implementations in this file report rejection by throwing (through the validation
 * scope's `orThrow()`), so a returned value always describes an accepted request.
 */
sealed interface RequestObjectJwtValidator {

    /**
     * Validates the serialized JWT [requestObjectJwt].
     *
     * @return the validated request data.
     */
    suspend fun validate(
        requestObjectJwt: String
    ): ValidatedRequestJwt
}

/**
 * Base [RequestObjectJwtValidator] with the checks shared by all profiles.
 *
 * Checks performed:
 * - `exp` and `nbf` against [clock], `aud`, and `iss` == [RequestObjectJwt.Keys.CLIENT_ID];
 * - presence of [RequestObjectJwt.Keys.NONCE] and [RequestObjectJwt.Keys.RESPONSE_URI];
 * - [RequestObjectJwt.Keys.RESPONSE_TYPE] is one of the supported values;
 * - [RequestObjectJwt.Keys.PRESENTATION_DEFINITION] is present and decodable;
 * - for [ClientIdScheme.PRE_REGISTERED] only, the client must be known to
 *   [preRegisteredClientStore];
 * - for [ClientIdScheme.PRE_REGISTERED] only, the client must share a VP format/algorithm
 *   with the wallet;
 * - for [ClientIdScheme.PRE_REGISTERED] only, the request must be signed by a key from the
 *   client's JWKS (downloaded with [httpClient]).
 *
 * Any other client id scheme is rejected. Subclasses such as [HaipRequestObjectJwtValidator]
 * add profile-specific constraints on top of [validate].
 */
sealed class DefaultRequestObjectJwtValidator(
    private val preRegisteredClientStore: PreRegisteredClientStore,
    private val clock: Clock,
    private val httpClient: HttpClient
) : RequestObjectJwtValidator {

    /**
     * Validates [requestObjectJwt] and returns the extracted request data.
     *
     * Problems are accumulated in the validation scope. When any were recorded, the block
     * yields `null` and `orThrow()` fails the call.
     */
    override suspend fun validate(requestObjectJwt: String) = suspendableValidationScope {
        // Only parsed here: the verification key depends on the client id scheme and client id,
        // both read from the payload. The signature is checked in validatePreRegisteredClient.
        val jws = JwsUtil.unsafeParseJws(requestObjectJwt)
        val payload = jws.payload

        val clientIdSchemeString = payload[RequestObjectJwt.Keys.CLIENT_ID_SCHEME]?.jsonPrimitive?.content
        if (clientIdSchemeString == null) {
            addValidationError("Missing ${RequestObjectJwt.Keys.CLIENT_ID_SCHEME} claim")
            return@suspendableValidationScope null
        }

        validateBasicClaims(payload)
        val presDef = validatePresentationDefinition(payload) ?: return@suspendableValidationScope null

        when (val clientIdScheme = ClientIdScheme.fromString(clientIdSchemeString)) {
            ClientIdScheme.PRE_REGISTERED -> validatePreRegisteredClient(payload, requestObjectJwt)
            else -> addValidationError("Unsupported ${RequestObjectJwt.Keys.CLIENT_ID_SCHEME} value: $clientIdScheme")
        }

        if (hasNoErrors) {
            // With no errors recorded these claims are guaranteed to be present:
            // response_uri, nonce and exp are checked by validateBasicClaims; iss is, because
            // validateBasicClaims requires iss == client_id and the only accepted client id
            // scheme branch requires client_id to be present.
            val responseUri = checkNotNull(payload[RESPONSE_URI]?.jsonPrimitive?.content) {
                "$RESPONSE_URI missing after successful validation"
            }
            val nonce = checkNotNull(payload[NONCE]?.jsonPrimitive?.content) {
                "$NONCE missing after successful validation"
            }
            val iss = checkNotNull(payload.iss) { "iss missing after successful validation" }
            val exp = checkNotNull(payload.exp) { "exp missing after successful validation" }
            val state = payload.customClaims[STATE]?.jsonPrimitive?.content

            ValidatedRequestJwt(
                jws = jws,
                responseUri = responseUri,
                iss = iss,
                nonce = nonce,
                presentationDefinition = presDef,
                state = state,
                expAt = exp
            )
        } else {
            null
        }
    }.orThrow()

    /**
     * Authenticates a request that uses [ClientIdScheme.PRE_REGISTERED].
     *
     * The client id must be known to [preRegisteredClientStore]. After that, the client's
     * VP formats and the request signature are checked.
     */
    private suspend fun ValidationScope.validatePreRegisteredClient(
        payload: JsonJwtClaimsSet,
        requestObjectJwt: String
    ) {
        val clientId = payload[RequestObjectJwt.Keys.CLIENT_ID]?.jsonPrimitive?.content
        if (clientId == null) {
            addValidationError("Missing ${RequestObjectJwt.Keys.CLIENT_ID} claim")
            return
        }
        val preRegisteredClient = preRegisteredClientStore.getByClientId(clientId)
        if (preRegisteredClient == null) {
            addValidationError("Unknown client $clientId")
            return
        }
        validateVpFormatForPreRegisteredClient(preRegisteredClient)
        validateSignatureForPreRegisteredClient(requestObjectJwt, preRegisteredClient)
    }

    /**
     * Decodes the [RequestObjectJwt.Keys.PRESENTATION_DEFINITION] claim.
     *
     * Both a JSON-encoded string and an embedded JSON object are accepted.
     *
     * @return the decoded definition, or `null` after recording a validation error when the
     * claim is missing or cannot be decoded.
     */
    private fun ValidationScope.validatePresentationDefinition(payload: JsonJwtClaimsSet): JsonPresentationDefinition? {
        val claimName = RequestObjectJwt.Keys.PRESENTATION_DEFINITION
        val element = payload[claimName]
        if (element == null) {
            addValidationError("Missing $claimName claim")
            return null
        }
        return try {
            // A string primitive carries the JSON text; anything else (e.g. an object) is
            // re-serialized so both forms go through the same decoder.
            val encoded = if (element is JsonPrimitive) element.content else element.toString()
            Json.decodeFromString<JsonPresentationDefinition>(encoded)
        } catch (e: IllegalArgumentException) {
            // kotlinx SerializationException extends IllegalArgumentException, so malformed
            // content becomes a validation error instead of escaping the validation scope.
            addValidationError("Invalid $claimName claim")
            null
        }
    }

    /**
     * Records an error for every generic claim problem. Nothing is returned.
     *
     * Checks:
     * - `exp` and `nbf` must be present, and the current time (epoch seconds from [clock]) must
     *   satisfy `nbf <= now <= exp`;
     * - [NONCE] and [RESPONSE_URI] must be present;
     * - [RESPONSE_TYPE] must be one of the supported values;
     * - `aud` must contain `https://self-issued.me/v2` or the `iss` value;
     * - `iss` must equal [RequestObjectJwt.Keys.CLIENT_ID].
     */
    private fun ValidationScope.validateBasicClaims(claims: JsonJwtClaimsSet) {
        val nowInEpochSeconds = clock.now().epochSeconds

        val exp = claims.exp
        if (exp == null) {
            addValidationError("exp claim is missing")
        } else if (nowInEpochSeconds > exp) {
            addValidationError("Request Object expired")
        }

        val nbf = claims.nbf
        if (nbf == null) {
            addValidationError("nbf claim is missing")
        } else if (nowInEpochSeconds < nbf) {
            addValidationError("Request Object used before the 'nbf' claim")
        }

        if (claims[NONCE]?.jsonPrimitive?.content == null) {
            addValidationError("Missing $NONCE claim")
        }

        // Required by ValidatedRequestJwt.responseUri; without this check a missing claim
        // surfaced as a NullPointerException in validate().
        if (claims[RESPONSE_URI]?.jsonPrimitive?.content == null) {
            addValidationError("Missing $RESPONSE_URI claim")
        }

        val responseType = claims[RESPONSE_TYPE]?.jsonPrimitive?.content
        if (responseType !in setOf("vp_token", "vp_token id_token", "code")) {
            addValidationError("Invalid value for the $RESPONSE_TYPE claim: $responseType.")
        }

        val iss = claims.iss
        val isAudValid = claims.aud?.all?.any { it == "https://self-issued.me/v2" || it == iss } == true
        if (!isAudValid) {
            addValidationError("Invalid value for the aud claim: ${claims.aud}.")
        }

        val clientId = claims[RequestObjectJwt.Keys.CLIENT_ID]?.jsonPrimitive?.content
        if (iss != clientId) {
            addValidationError("iss claim value doesn't match ${RequestObjectJwt.Keys.CLIENT_ID} claim value.")
        }
    }

    /**
     * Records an error unless the wallet and [preRegisteredClient] share a VP format with at
     * least one common algorithm.
     */
    private fun ValidationScope.validateVpFormatForPreRegisteredClient(preRegisteredClient: PreRegisteredClient) {
        val isFormatSupported = WalletMetadata.vpFormatsSupported.any { walletFormat ->
            preRegisteredClient.vpFormatsSupported.any { clientFormat ->
                clientFormat.format == walletFormat.format &&
                    clientFormat.algs.intersect(walletFormat.algs).isNotEmpty()
            }
        }
        if (!isFormatSupported) {
            addValidationError("Format not supported")
        }
    }

    /**
     * Verifies the signature of [jws] with keys from [preRegisteredClient]'s JWKS.
     *
     * A [CryptoException] from loading the keys or from verification is recorded as
     * "Invalid signature.". If the JWKS cannot be fetched, verification is skipped;
     * [fetchJwks] has already recorded that error.
     */
    private suspend fun ValidationScope.validateSignatureForPreRegisteredClient(
        jws: String, preRegisteredClient: PreRegisteredClient
    ) {
        try {
            val keySet: VerificationKeySet = fetchJwks(preRegisteredClient.jwksUri) ?: return
            JwsUtil.verifyJws(keySet, jws)
        } catch (ex: CryptoException) {
            addValidationError("Invalid signature.")
        }
    }

    /**
     * Downloads and parses the JWKS at [jwksUri].
     *
     * @return the key set, or `null` after recording a validation error when the response
     * status is not 200. Exceptions thrown by [httpClient] itself are not caught here.
     */
    private suspend fun ValidationScope.fetchJwks(jwksUri: String): VerificationKeySet? {
        val response = httpClient.get(jwksUri)
        if (response.status.value != 200) {
            addValidationError("Cannot fetch JWKS from: $jwksUri. Status: ${response.status.value}")
            // Don't parse an error body as a JWKS: that could throw or add a misleading
            // "Invalid signature." error on top of this one.
            return null
        }
        return VerificationKeySet.fromJwks(response.body())
    }
}

/**
 * Validator that applies stricter, profile-specific rules on top of
 * [DefaultRequestObjectJwtValidator]. The name presumably refers to the High Assurance
 * Interoperability Profile (NOTE(review): confirm).
 *
 * Beyond the base checks it requires:
 * - [RESPONSE_MODE] == `direct_post`;
 * - [RESPONSE_TYPE] == `vp_token`;
 * - an `https://` [RESPONSE_URI].
 */
class HaipRequestObjectJwtValidator(
    preRegisteredClientStore: PreRegisteredClientStore,
    clock: Clock,
    httpClient: HttpClient
) : DefaultRequestObjectJwtValidator(preRegisteredClientStore, clock, httpClient) {

    override suspend fun validate(requestObjectJwt: String) = suspendableValidationScope {
        // The base validation throws on failure, so a result here has passed every generic check.
        val baseResult = super.validate(requestObjectJwt)
        val claims = baseResult.jws.payload

        val responseMode = claims[RESPONSE_MODE]?.jsonPrimitive?.content
        if (responseMode != "direct_post") {
            addValidationError("Invalid $RESPONSE_MODE. Only 'direct_post' is accepted.")
        }

        val responseType = claims[RESPONSE_TYPE]?.jsonPrimitive?.content
        if (responseType != "vp_token") {
            addValidationError("Invalid $RESPONSE_TYPE. Only 'vp_token' is accepted.")
        }

        // With response_mode = direct_post the wallet POSTs the response, so the target must be https.
        val responseUri = claims[RESPONSE_URI]?.jsonPrimitive?.content
        val usesHttps = responseUri?.startsWith("https://") == true
        if (!usesHttps) {
            addValidationError("Invalid value for the $RESPONSE_URI claim: $responseUri. 'https' is required.")
        }

        baseResult.takeIf { hasNoErrors }
    }.orThrow()
}

/**
 * Request data extracted from a request object JWT that passed validation.
 *
 * @property jws the parsed request object JWT.
 * @property responseUri value of the [RequestObjectJwt.Keys.RESPONSE_URI] claim.
 * @property iss the request object's `iss` claim.
 * @property nonce value of the [RequestObjectJwt.Keys.NONCE] claim.
 * @property presentationDefinition the decoded presentation definition.
 * @property state value of the [RequestObjectJwt.Keys.STATE] claim, or `null` when absent.
 * @property expAt the `exp` claim, in epoch seconds.
 */
@Serializable
data class ValidatedRequestJwt(
    val jws: JsonJwt.Jws,
    val responseUri: String,
    val iss: String,
    val nonce: String,
    val presentationDefinition: JsonPresentationDefinition,
    val state: String?,
    val expAt: Long,
)