/*
 * Copyright (C) 2024 Curity AB. All rights reserved.
 *
 * The contents of this file are the property of Curity AB.
 * You may not copy or use this file, in either source code
 * or executable form, except in compliance with terms
 * set by Curity AB.
 *
 * For further information, please contact Curity AB.
 */

package io.curity.vc.services

import io.curity.ssi.crypto.CryptoException
import io.curity.ssi.crypto.VerificationKey
import io.curity.ssi.data.Dictionary
import io.curity.ssi.did.DidUrl
import io.curity.ssi.did.document.DidResolutionException
import io.curity.ssi.did.document.DidResolver
import io.curity.ssi.did.json.JsonPublicKeyJwk
import io.curity.ssi.did.json.SerializableDidUrl
import io.curity.ssi.jose.JsonJwsHeader
import io.curity.ssi.jose.JsonJwtClaimsSet
import io.curity.ssi.jose.JwsUtil
import io.curity.ssi.json.data.ToJsonDictionary
import io.curity.ssi.validation.ValidationScope
import io.curity.ssi.validation.suspendableValidationScope
import io.curity.vc.CredentialEndpoint.VerifiableCredential.StringCredential
import io.curity.vc.Issuer
import io.curity.vc.serialization.DidValidatedJwtVcJsonCredentialResponse
import io.curity.vc.serialization.JsonCredentialIssuer
import io.curity.vc.serialization.JsonDisplayable
import io.curity.vc.serialization.JwtVcJsonCredentialConfigurationsSupported
import io.curity.vc.serialization.JwtVcJsonCredentialResponse
import io.curity.vc.serialization.JwtVcJsonLdCredentialResponse
import io.curity.vc.serialization.JwtVcJsonVerifiableCredentialResponse
import io.curity.vc.serialization.LdpVcCredentialResponse
import io.curity.vc.serialization.SerializableW3CVerifiableCredential
import io.curity.vc.serialization.ValidatedJwtVcJsonCredentialResponse
import io.curity.vc.serialization.ValidatedJwtVcJsonLdCredentialResponse
import io.curity.vc.serialization.ValidatedLdpVcCredentialResponse
import io.curity.vc.serialization.ValidatedVcSdJwtCredentialResponse
import io.curity.vc.serialization.VcSdJwtCredentialResponse
import io.curity.vc.services.JwtCredentialValidationHelper.validateCredentialClaims
import io.curity.vc.services.JwtCredentialValidationHelper.validateJwtSignature
import kotlinx.datetime.Clock
import kotlinx.datetime.Instant
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
/**
 * [VerifiableCredentialResponseValidator] for credential responses whose subject is identified by a DID.
 *
 * Only the `jwt_vc_json` format is handled: [validateJwtVcJson] delegates to a freshly created
 * [JwtVcJsonCredentialWithDidSubjectValidator], while every other entry point throws
 * [UnsupportedOperationException].
 */
class DidSubjectVerifiableCredentialResponseValidator(
    private val issuer: JsonCredentialIssuer,
    private val didSubjectIdentifier: DidUrl,
    private val clock: Clock = Clock.System,
    private val supportedCredential: JwtVcJsonCredentialConfigurationsSupported,
    private val didResolver: DidResolver<JsonElement>,
) : VerifiableCredentialResponseValidator {

    /** Validates a `jwt_vc_json` response with the dependencies this instance was constructed with. */
    override suspend fun validateJwtVcJson(response: JwtVcJsonCredentialResponse): ValidatedJwtVcJsonCredentialResponse =
        JwtVcJsonCredentialWithDidSubjectValidator(
            issuer = issuer,
            didSubjectIdentifier = didSubjectIdentifier,
            clock = clock,
            supportedCredential = supportedCredential,
            didResolver = didResolver,
        ).validate(response)

    /** Not supported by this validator; always throws [UnsupportedOperationException]. */
    override suspend fun validateJwtVcJsonLd(response: JwtVcJsonLdCredentialResponse): ValidatedJwtVcJsonLdCredentialResponse =
        unsupportedFormat()

    /** Not supported by this validator; always throws [UnsupportedOperationException]. */
    override suspend fun validateLdp(response: LdpVcCredentialResponse): ValidatedLdpVcCredentialResponse =
        unsupportedFormat()

    /** Not supported by this validator; always throws [UnsupportedOperationException]. */
    override suspend fun validateVcSdJwt(response: VcSdJwtCredentialResponse): ValidatedVcSdJwtCredentialResponse =
        unsupportedFormat()

    /** Single place for the rejection raised by every format other than `jwt_vc_json`. */
    private fun unsupportedFormat(): Nothing =
        throw UnsupportedOperationException("This validator can only handle the 'jwt_vc_json' format")
}

/**
 * Validates one `jwt_vc_json` credential response issued to a DID subject.
 *
 * The JWS header must carry a `kid` that is a DID URL, the payload must pass
 * [validateCredentialClaims], and the signature must verify against the key that
 * [didResolver] yields for that `kid`.
 *
 * NOTE(review): [issuer] is accepted but never read inside this class — confirm whether the
 * JWT `iss` claim is supposed to be checked against it.
 */
private class JwtVcJsonCredentialWithDidSubjectValidator(
    private val issuer: JsonCredentialIssuer,
    private val didSubjectIdentifier: SerializableDidUrl,
    private val clock: Clock,
    private val supportedCredential: JwtVcJsonCredentialConfigurationsSupported,
    private val didResolver: DidResolver<JsonElement>,
) {

    /**
     * Runs header, payload and signature validation for [response] inside one validation scope.
     *
     * @param response the credential response; a deferred response is rejected via `error(...)`.
     * @return the validated response together with the parsed credential and display information.
     *   Failures recorded in the scope surface through `orThrow()`.
     */
    suspend fun validate(
        response: JwtVcJsonCredentialResponse
    ): DidValidatedJwtVcJsonCredentialResponse = suspendableValidationScope {
        val credential: StringCredential = when (response) {
            is JwtVcJsonCredentialResponse.Deferred -> error("Cannot handle deferred response")
            is JwtVcJsonVerifiableCredentialResponse -> response.credential
        }

        // Signature verification happens further down, once the 'kid' header has been
        // turned into a DID URL that identifies the verification key.
        val unsafeJws = JwsUtil.unsafeParseJws(credential.value)

        val header = unsafeJws.header
        val didKid = if (header == null) {
            addValidationError("JWS Header is missing")
            null
        } else {
            validateHeader(header)
        }

        validatePayload(unsafeJws.payload)?.let { (w3cCredential, subjectDisplayInformation) ->
            // The signature is only checked when everything before it passed.
            if (hasNoErrors && didKid != null) {
                if (!validateJwtSignature(credential.value, didKid, didResolver)) {
                    addValidationError("Unable to validate signature")
                    null
                } else {
                    DidValidatedJwtVcJsonCredentialResponse(
                        response,
                        credential.value,
                        w3cCredential,
                        subjectDisplayInformation,
                        supportedCredential.display
                    )
                }
            } else {
                null
            }
        }
    }.orThrow()

    /**
     * Extracts the `kid` header parameter as a [DidUrl].
     *
     * Exactly one validation error is recorded per failure: either the parameter is absent,
     * or it is present but is not a DID URL.
     *
     * @return the parsed DID URL, or `null` when a validation error was recorded.
     */
    private fun ValidationScope.validateHeader(header: JsonJwsHeader): DidUrl? {
        val kid = header.kid
        if (kid == null) {
            addValidationError("JWS 'kid' header parameter is missing")
            return null
        }
        return DidUrl.from(kid).ifNull { addValidationError("JWS 'kid' is not a valid DID URL") }
    }

    /**
     * Validates the JWT claims and collects the per-claim display metadata from [supportedCredential].
     *
     * @return the parsed credential paired with a map from claim name to its display entries
     *   (claims without display data are left out), or `null` when the scope holds errors or
     *   no credential could be produced.
     */
    private fun ValidationScope.validatePayload(payload: JsonJwtClaimsSet):
            Pair<SerializableW3CVerifiableCredential, Map<String, List<JsonDisplayable>>>? {

        val w3cCredential = validateCredentialClaims(payload, didSubjectIdentifier, clock)

        val subjectDisplayInformation = supportedCredential.credentialDefinition.credentialSubject
            ?.claims
            ?.entries
            ?.mapNotNull { (key, value) -> value.display?.let { display -> key to display } }
            ?.toMap()
            ?: emptyMap()

        return if (hasNoErrors && w3cCredential != null) {
            w3cCredential to subjectDisplayInformation
        } else {
            null
        }
    }

    companion object {

        /** Runs [action] when the receiver is `null`, then returns the receiver unchanged. */
        private fun <T> T?.ifNull(action: () -> Unit): T? {
            if (this == null) action()
            return this
        }
    }

}

/**
 * Validation steps for JWT-encoded verifiable credentials, written as [ValidationScope]
 * extensions so that every failure is recorded in the caller's scope.
 */
object JwtCredentialValidationHelper {

    /** Largest difference, in whole seconds, at which [sameish] still treats two instants as equal. */
    private const val TIMESTAMP_TOLERANCE_SECONDS = 2L

    /**
     * Verifies the signature of [jwt] with the key that [didUrl] points to.
     *
     * The DID document for `didUrl.did` is resolved through [didResolver], the verification
     * method for [didUrl] is looked up, and its `publicKeyJwk` is converted to a
     * [VerificationKey] that is then used to verify the JWS.
     *
     * @return `true` when the signature verifies; `false` otherwise, in which case a
     *   validation error describing the failing step has been recorded.
     */
    suspend fun ValidationScope.validateJwtSignature(
        jwt: String,
        didUrl: DidUrl,
        didResolver: DidResolver<JsonElement>,
    ): Boolean {
        val issuerKey = try {
            val publicKeyJwk = didResolver.resolve(didUrl.did).document
                .getVerificationMethodFor(didUrl)
                ?.publicKeyJwk
            if (publicKeyJwk == null) {
                addValidationError("Unable to find usable key in DID document")
            }
            publicKeyJwk
        } catch (ex: DidResolutionException) {
            addValidationError("Unable to find DID document")
            return false
        } ?: return false

        val verificationKey = try {
            VerificationKey.fromJwk(Json.encodeToString(JsonPublicKeyJwk.from(issuerKey)))
        } catch (ex: CryptoException) {
            addValidationError("Unable to use key resolved from DID URL")
            return false
        }

        return try {
            JwsUtil.verifyJws(verificationKey, jwt)
            true
        } catch (ex: CryptoException) {
            addValidationError("Signature is not valid")
            false
        }
    }

    /**
     * Validates the registered JWT claims of [payload] and the embedded `vc` claim.
     *
     * Checks performed here: `sub` equals the DID of [didSubjectIdentifier]; `exp` and `iat`
     * are passed through `mandatory`, with `exp` after and `iat` not after `clock.now()`; the
     * `vc` claim is present and converts to a [SerializableW3CVerifiableCredential], which is
     * then cross-checked against the JWT claims by [validateVcClaim].
     *
     * @return the parsed credential, or `null` when the `vc` claim is missing or cannot be
     *   converted. A non-null result does not imply success — other checks may have recorded
     *   errors in the scope.
     */
    fun ValidationScope.validateCredentialClaims(
        payload: JsonJwtClaimsSet,
        didSubjectIdentifier: SerializableDidUrl,
        clock: Clock,
    ): SerializableW3CVerifiableCredential? {

        val iss = payload.iss
        val sub = payload.sub.also { sub ->
            isTrue("'sub' claim matches subject identifier") {
                sub == didSubjectIdentifier.did.toString()
            }
        }

        // TODO requirements on jti
        val jti = payload.jti

        val exp = mandatory("exp", payload.exp) { exp ->
            val expInstant = Instant.fromEpochSeconds(exp)
            isTrue("'exp' claim is in the future") {
                clock.now() < expInstant
            }
            expInstant
        }

        val iat = mandatory("iat", payload.iat) { iat ->
            val iatInstant = Instant.fromEpochSeconds(iat)
            isTrue("'iat' is not in the future") {
                iatInstant <= clock.now()
            }
            iatInstant
        }

        // Without a 'vc' claim there is no credential to return; record why, so that the
        // caller does not end up with a null result and an empty error list.
        val vc = payload["vc"]
        if (vc == null) {
            addValidationError("Missing claim 'vc'")
            return null
        }

        return tryConvert({ err -> "Credential does not conform to w3c data model: $err" }, vc) { credential ->
            SerializableW3CVerifiableCredential.fromJson(credential)
        }?.also { vcClaim -> validateVcClaim(vcClaim, iss, sub, jti, iat, exp, didSubjectIdentifier) }
    }

    /**
     * Cross-checks the fields of the embedded credential [vc] against the surrounding JWT claims.
     *
     * The type list must contain `VerifiableCredential`, and `vc.issuer`, `vc.id`, `vc.issued`
     * and `vc.expirationDate` must agree with [iss], [jti], [iat] and [exp] respectively (the
     * timestamps within the tolerance of [sameish]). Every credential subject is handed to
     * [validateVcCredentialSubject].
     */
    private fun ValidationScope.validateVcClaim(
        vc: SerializableW3CVerifiableCredential,
        iss: String?,
        sub: String?,
        jti: String?,
        iat: Instant?,
        exp: Instant?,
        didSubjectIdentifier: SerializableDidUrl,
    ) {

        val typeList = vc.type
        isTrue("type contains 'VerifiableCredential'") {
            typeList.contains("VerifiableCredential")
        }

        val vcIssuer = vc.issuer
        isTrue("vc.issuer matches jwt.issuer") {
            // The issuer is either an object carrying an id or a bare URI value.
            val issuerId = when (vcIssuer) {
                is Issuer.IssuerObject -> vcIssuer.id
                is Issuer.IssuerURI -> vcIssuer.value
            }
            issuerId == iss
        }

        val id = vc.id
        isTrue("vc.id must match jwt.jti") {
            id == jti
        }

        // TODO check if `vc.issued` is the same as `vc.issuanceDate`
        // FIXME IS-8293
        if (iat != null) {
            val issuedString = vc.issued
            tryConvert({ err -> "vc.issued is invalid: $err" }, issuedString) {
                Instant.parse(issuedString)
            }?.let { issued ->
                isTrue("vc.issued matches jwt.iat") {
                    sameish(issued, iat)
                }
            }
        }

        // Only compared when the credential carries an expiration date; a null `exp` fails the check.
        vc.expirationDate?.let { expirationDate ->
            tryConvert({ err -> "vc.expirationDate is invalid: $err" }, expirationDate) {
                Instant.parse(expirationDate)
            }?.let { expiration ->
                isTrue("vc.expirationDate matches jwt.exp") {
                    sameish(expiration, exp)
                }
            }
        }

        vc.credentialSubject.forEach { subject ->
            validateVcCredentialSubject(subject, sub, didSubjectIdentifier)
        }
    }

    /**
     * Requires [credentialSubject] to carry an `id` that equals both the DID of
     * [didSubjectIdentifier] and the JWT `sub` claim.
     *
     * The `id` may be a JSON primitive or a JSON object that itself holds a primitive `id`;
     * any other shape is treated as having no id and therefore fails both comparisons
     * (unless [sub] is `null`, in which case the second comparison holds).
     */
    private fun ValidationScope.validateVcCredentialSubject(
        credentialSubject: ToJsonDictionary,
        sub: String?,
        didSubjectIdentifier: SerializableDidUrl,
    ) {
        claim<JsonElement>(credentialSubject, "id") { idJsonElement ->
            val credentialSubjectId: String? = when (idJsonElement) {
                is JsonPrimitive -> idJsonElement.content
                is JsonObject -> (idJsonElement["id"] as? JsonPrimitive)?.content
                else -> null
            }
            isTrue("credentialSubject.id must match the used subject identifier") {
                credentialSubjectId == didSubjectIdentifier.did.toString()
            }
            isTrue("credentialSubject.id must match jwt.sub") {
                credentialSubjectId == sub
            }
        }
    }

    /**
     * Looks up [name] in [payload]. A missing entry records a "Missing claim" validation
     * error; a present one is passed to [block] for further checks.
     *
     * @return the looked-up value, or `null` when it is absent.
     */
    private inline fun <reified T> ValidationScope.claim(
        payload: Dictionary<T>,
        name: String,
        block: ValidationScope.(T) -> Unit,
    ): T? {
        val value = payload[name]
        if (value == null) {
            addValidationError("Missing claim '$name'")
        } else {
            block(value)
        }
        return value
    }

    /**
     * `true` when both instants are non-null and differ by at most
     * [TIMESTAMP_TOLERANCE_SECONDS] whole seconds (sub-second remainders are truncated).
     */
    private fun sameish(i1: Instant?, i2: Instant?): Boolean =
        i1 != null && i2 != null &&
            i1.minus(i2).absoluteValue.inWholeSeconds <= TIMESTAMP_TOLERANCE_SECONDS

}