/*
 * Copyright (C) 2024 Curity AB. All rights reserved.
 *
 * The contents of this file are the property of Curity AB.
 * You may not copy or use this file, in either source code
 * or executable form, except in compliance with terms
 * set by Curity AB.
 *
 * For further information, please contact Curity AB.
 */

package io.curity.ssi.sdjwt

import io.curity.ssi.crypto.HashFunction
import io.curity.ssi.crypto.createHashFunction
import io.curity.ssi.json.JsonToNativeSerializationAdapter
import io.curity.ssi.json.NativeToJsonSerializationAdapter
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive

/**
 * Class that is able to conceal nodes of a [JsonObject], producing a selectively-disclosable object
 * together with the [Disclosure]s needed to reveal the concealed nodes.
 *
 * @property hashAlgorithmName the name of the used hash algorithm; written to the
 *                              [SdJwtConstants.CLAIM_SD_ALG] claim of every concealed root object
 * @property hashFunction the hash function used to compute the digest of each disclosure
 * @property concealPredicate predicate that should return `true` on any property or array index that must
 *                              be concealed. It receives the path of the node and its original (not yet
 *                              concealed) value. It is never invoked for [SdJwtConstants.CLAIM_SD_ALG].
 */
/*
 * Implementation note: the current design uses a predicate over (path, JSON element) pairs to decide
 * which parts of the JsonObject to conceal.
 * This design provides greater flexibility in deciding which parts of the JSON object should be concealed.
 * However, it also implies navigating the complete JSON object, which may not be desirable.
 * We may need to revisit this decision when we know more about usage patterns.
 */
class SdJwtConcealer(
    val hashAlgorithmName: String,
    val hashFunction: HashFunction,
    val concealPredicate: (JsonPath, JsonElement) -> Boolean
) {

    /**
     * The concealment result.
     * @property obj the object with the concealed properties and array indexes
     * @property disclosures the disclosures created for the concealed properties and array elements,
     *                       in depth-first order (nested disclosures precede the disclosure that contains them)
     */
    data class Result<T>(
        val obj: T,
        val disclosures: List<Disclosure>
    )

    /**
     * Conceals a [JsonObject].
     *
     * The returned object always contains [SdJwtConstants.CLAIM_SD_ALG] set to [hashAlgorithmName],
     * replacing any value already present in [obj]. That claim is kept in clear text and is never
     * offered to [concealPredicate], since a verifier needs it to recompute the disclosure digests.
     *
     * @param obj the object to conceal
     * @return the concealed object and the disclosures created while concealing it
     */
    suspend fun conceal(
        obj: JsonObject,
    ): Result<JsonObject> {
        val disclosures = mutableListOf<Disclosure>()
        val objWithHashAlgClaim = obj.toMutableMap().also {
            it[SdJwtConstants.CLAIM_SD_ALG] = JsonPrimitive(hashAlgorithmName)
        }
        val concealedObject = concealObject(
            JsonObject(objWithHashAlgClaim),
            disclosures,
            null,
            // The hash algorithm claim must stay visible at the top level; never conceal it.
            neverConcealed = setOf(SdJwtConstants.CLAIM_SD_ALG),
        )
        return Result(
            concealedObject,
            disclosures.toList()
        )
    }

    /**
     * Conceals a native map, using the same rules as the [JsonObject] overload.
     *
     * @param map the object to conceal, expressed as native values
     * @param nativeToJsonSerializationAdapter adapter used to convert [map] into a [JsonObject]
     * @return the concealed object converted back to a native map, and the created disclosures
     * @throws IllegalArgumentException if [map] does not convert to a [JsonObject]
     */
    suspend fun conceal(
        map: Map<String, *>,
        nativeToJsonSerializationAdapter: NativeToJsonSerializationAdapter = NativeToJsonSerializationAdapter.DEFAULT,
    ): Result<Map<*, *>> {
        val obj = nativeToJsonSerializationAdapter.convertValue(map) as? JsonObject
            ?: throw IllegalArgumentException("input should be a valid JSON object")
        // Named argument makes the target overload explicit (JsonObject is itself a Map<String, *>).
        val (concealedObject, disclosures) = conceal(obj = obj)
        val concealedMap = JsonToNativeSerializationAdapter.convertValue(concealedObject) as Map<*, *>
        return Result(
            concealedMap,
            disclosures
        )
    }

    /**
     * Recursively conceals the children of [element] if it is an object or an array.
     * Any other element is returned unchanged.
     */
    private suspend fun concealElement(
        element: JsonElement,
        disclosures: MutableList<Disclosure>,
        path: JsonPath,
    ): JsonElement = when (element) {
        is JsonObject -> concealObject(element, disclosures, path)
        is JsonArray -> concealArray(element, disclosures, path)
        else -> element
    }

    /**
     * Conceals the properties of [obj] selected by [concealPredicate].
     *
     * Each property value is concealed recursively first. The predicate is then evaluated on the original
     * value. Concealed properties are removed and their disclosure digests are collected into a
     * [SdJwtConstants.CLAIM_SD] array, which is appended at the end of the object.
     *
     * NOTE(review): if [obj] already has a property named [SdJwtConstants.CLAIM_SD] and some property is
     * concealed, the generated array replaces the original value.
     *
     * @param path the path of [obj], or `null` for the root object
     * @param neverConcealed property names that are copied unchanged, without consulting [concealPredicate]
     */
    private suspend fun concealObject(
        obj: JsonObject,
        disclosures: MutableList<Disclosure>,
        path: JsonPath?,
        neverConcealed: Set<String> = emptySet(),
    ): JsonObject {
        val nonConcealedProperties = mutableListOf<Pair<String, JsonElement>>()
        val concealedDigests = mutableListOf<String>()
        for ((key, value) in obj) {
            if (key in neverConcealed) {
                nonConcealedProperties.add(key to value)
                continue
            }
            val propertyPath = path.withProperty(key)
            val handledValue = concealElement(value, disclosures, propertyPath)
            if (concealPredicate(propertyPath, value)) {
                val disclosure = Disclosure.Property.create(key, handledValue)
                disclosures.add(disclosure)
                concealedDigests.add(disclosure.toDisclosureHashString(hashFunction))
            } else {
                nonConcealedProperties.add(key to handledValue)
            }
        }
        if (concealedDigests.isNotEmpty()) {
            nonConcealedProperties.add(
                SdJwtConstants.CLAIM_SD to
                        JsonArray(concealedDigests.map { JsonPrimitive(it) })
            )
        }

        return JsonObject(nonConcealedProperties.toMap())
    }

    /**
     * Conceals the elements of [array] selected by [concealPredicate].
     *
     * Each element is concealed recursively first. The predicate is then evaluated on the original element.
     * A concealed element is replaced, at the same index, by a single-property object mapping
     * [SdJwtConstants.CLAIM_ARRAY_VALUE] to the disclosure digest.
     */
    private suspend fun concealArray(
        array: JsonArray,
        disclosures: MutableList<Disclosure>,
        path: JsonPath,
    ): JsonArray = JsonArray(
        array.mapIndexed { index, elem ->
            val indexPath = path.withIndex(index)
            val concealedElem: JsonElement = concealElement(elem, disclosures, indexPath)
            if (concealPredicate(indexPath, elem)) {
                val disclosure = Disclosure.Value.create(concealedElem)
                disclosures.add(disclosure)
                JsonObject(
                    mapOf(
                        SdJwtConstants.CLAIM_ARRAY_VALUE to JsonPrimitive(disclosure.toDisclosureHashString(hashFunction))
                    )
                )
            } else {
                concealedElem
            }
        }
    )

    companion object {
        /**
         * Creates a concealer that uses [SdJwtConstants.DEFAULT_HASH_FUNCTION].
         *
         * @throws IllegalStateException if the default hash function cannot be created
         */
        fun createUsingDefaultHashFunction(concealPredicate: (JsonPath, JsonElement) -> Boolean) =
            SdJwtConcealer(
                SdJwtConstants.DEFAULT_HASH_FUNCTION,
                checkNotNull(createHashFunction(SdJwtConstants.DEFAULT_HASH_FUNCTION)) {
                    "Default hash function not supported"
                },
                concealPredicate,
            )
    }
}

