/*
 * Copyright (C) 2024 Curity AB. All rights reserved.
 *
 * The contents of this file are the property of Curity AB.
 * You may not copy or use this file, in either source code
 * or executable form, except in compliance with terms
 * set by Curity AB.
 *
 * For further information, please contact Curity AB.
 */

package io.curity.ssi.sdjwt

import io.curity.ssi.crypto.createHashFunction
import io.curity.ssi.crypto.encodeBase64url
import io.curity.ssi.crypto.sha256
import io.curity.ssi.jose.JsonJwt
import io.curity.ssi.validation.ValidationResult
import kotlinx.serialization.json.JsonPrimitive
import kotlin.jvm.JvmOverloads
import kotlin.jvm.JvmStatic

/**
 * Represents a parsed SD-JWT.
 *
 * The compact serialization handled by [encode] and [decode] is
 * `<JWT>~<Disclosure 1>~...~<Disclosure N>~<optional KB-JWT>`, with parts separated by
 * [SdJwtConstants.DELIMITER].
 *
 * @property jwtString The original JWT string
 * @property jwt the parsed JWT
 * @property disclosedCustomClaims the JWT claims with disclosure information
 * @property disclosureStrings the disclosure strings, present outside the JWT
 * @property keyBindingJwtString the encoded key-binding JWT (KB-JWT), or `null` when absent
 * @property keyBindingJwt the parsed key-binding JWT (KB-JWT), or `null` when absent
 */
class SdJwt(
    val jwtString: String,
    val jwt: JsonJwt.Jws,
    val disclosedCustomClaims: SdValue.Object,
    val disclosureStrings: List<String>,
    val keyBindingJwtString: String?,
    val keyBindingJwt: JsonJwt.Jws?,
) {

    /**
     * Produces an encoded SD-JWT with *all* of [disclosureStrings], followed by
     * [keyBindingJwtString] when it is present.
     */
    fun encodeAsStringWithAllDisclosures(): String = encode(jwtString, disclosureStrings, keyBindingJwtString)

    /**
     * Produces an encoded SD-JWT with only the given [disclosures] (typically a subset of
     * [disclosureStrings]), followed by [keyBindingJwtString] when it is present.
     *
     * Despite the method name, [disclosures] are encoded disclosure strings, not claim names. They are
     * written verbatim and in order, and are not checked against [disclosureStrings].
     *
     * NOTE(review): the existing KB-JWT is appended unchanged. Its `sd_hash` (see [computeKbJwtSdHash])
     * is computed over a specific set of disclosures, so it presumably will not verify if [disclosures]
     * differs from that set.
     */
    fun encodeAsStringWithClaims(disclosures: List<String>): String =
        encode(jwtString, disclosures, keyBindingJwtString)

    companion object {
        /**
         * Serializes an SD-JWT as `<jwt>~<disclosure 1>~...~<disclosure N>~<kb-jwt>`.
         *
         * A delimiter always follows the JWT and every disclosure. The output therefore ends with the
         * delimiter when [keyBindingJwtString] is `null`. That form is also the hash input used by
         * [computeKbJwtSdHash].
         *
         * @param jwtString the encoded JWT
         * @param disclosureStrings the encoded disclosures, written in the given order
         * @param keyBindingJwtString the encoded KB-JWT to append, or `null` to omit it
         * @return the compact serialization
         */
        @JvmStatic
        @JvmOverloads
        fun encode(
            jwtString: String,
            disclosureStrings: List<String>,
            keyBindingJwtString: String? = null,
        ): String = buildString {
            append(jwtString).append(SdJwtConstants.DELIMITER)
            disclosureStrings.forEach { disclosure ->
                append(disclosure).append(SdJwtConstants.DELIMITER)
            }
            // Explicit null check: StringBuilder.append(null) would write the text "null".
            if (keyBindingJwtString != null) {
                append(keyBindingJwtString)
            }
        }

        /**
         * Splits an encoded SD-JWT into its JWT, disclosure and KB-JWT parts.
         *
         * This is a purely structural split on [SdJwtConstants.DELIMITER]. No base64url decoding or
         * signature verification is performed.
         *
         * The first part is the JWT and the parts between the first and the last are disclosures. The
         * last part is whatever follows the final delimiter. It is taken as the KB-JWT when it is not
         * blank; a blank last part (trailing delimiter) means there is no KB-JWT.
         *
         * @param encodedToken the compact serialization, e.g. `<jwt>~<d1>~<d2>~` or `<jwt>~<d1>~<kb-jwt>`
         * @return a valid [Decoded], or an invalid result if the token contains no delimiter, does not
         *   start with a JWT, or contains an empty or blank disclosure (for example, two consecutive
         *   delimiters)
         */
        @JvmStatic
        fun decode(encodedToken: String): ValidationResult<Decoded> {
            val parts = encodedToken.split(SdJwtConstants.DELIMITER)
            if (parts.size < 2) {
                return ValidationResult.invalid("Invalid SD-JWT: missing at least one '${SdJwtConstants.DELIMITER}'")
            }
            val jwtString = parts.first()
            if (jwtString.isEmpty()) {
                return ValidationResult.invalid("Invalid SD-JWT: needs to start with a JWT")
            }
            // The last part (after the final delimiter) is not a disclosure; it is the optional KB-JWT.
            val disclosures = parts.slice(1..<parts.lastIndex)
            if (disclosures.any { it.isBlank() }) {
                return ValidationResult.invalid("Invalid SD-JWT: empty disclosure")
            }
            val keyBindingJwtString = parts.last().ifBlank { null }
            return ValidationResult.valid(Decoded(jwtString, disclosures, keyBindingJwtString))
        }

        /**
         * Computes the value of the `sd_hash` payload claim to use in the KB-JWT.
         *
         * The hash input is the SD-JWT serialized *without* a KB-JWT, i.e. `<jwt>~<d1>~...~<dN>~`
         * (see [encode]). The hash function is chosen by [getHashFunctionToUse].
         *
         * @param jwt the parsed JWT; its payload selects the hash algorithm
         * @param jwtString the encoded JWT included in the hash input
         * @param disclosures the encoded disclosures included in the hash input, in order
         * @return the digest, base64url-encoded
         * @throws IllegalArgumentException if the JWT's [SdJwtConstants.CLAIM_SD_ALG] claim is not a
         *   string or names an unsupported algorithm
         */
        @JvmStatic
        suspend fun computeKbJwtSdHash(
            jwt: JsonJwt.Jws,
            jwtString: String,
            disclosures: List<String>,
        ): String {
            val hashFunction = getHashFunctionToUse(jwt)
            val hashInput = encode(jwtString, disclosures)
            return encodeBase64url(hashFunction(hashInput.encodeToByteArray()))
        }

        /**
         * Selects the hash function named by the JWT's [SdJwtConstants.CLAIM_SD_ALG] payload claim.
         * Falls back to SHA-256 when the claim is absent.
         *
         * @throws IllegalArgumentException if the claim is present but is not a JSON string, or names an
         *   algorithm for which [createHashFunction] returns `null`
         */
        @JvmStatic
        fun getHashFunctionToUse(jwt: JsonJwt.Jws) =
            jwt.payload[SdJwtConstants.CLAIM_SD_ALG]?.let { sdAlgElement ->
                if (sdAlgElement !is JsonPrimitive || !sdAlgElement.isString) {
                    throw IllegalArgumentException("Invalid SD-JWT: ${SdJwtConstants.CLAIM_SD_ALG} is not a string")
                }
                val hashAlg = sdAlgElement.content
                createHashFunction(hashAlg)
                    ?: throw IllegalArgumentException("Invalid SD-JWT: Unsupported hash algorithm $hashAlg")
            } ?: sha256()
    }

    /**
     * The structural parts of an encoded SD-JWT, as produced by [decode].
     *
     * @property jwt the encoded JWT (the first part)
     * @property disclosures the encoded disclosures, in their original order
     * @property keyBindingJwt the encoded KB-JWT, or `null` when absent
     */
    data class Decoded(
        val jwt: String,
        val disclosures: List<String>,
        val keyBindingJwt: String? = null,
    )
}
