// Chest X-ray classifier model built on a TensorFlow.js graph model (mlmed/chester-xray).
import * as tf from "@tensorflow/tfjs";
import {Model} from "./model";


/**
 * Chest X-ray classifier backed by a TensorFlow.js graph model from the
 * mlmed/chester-xray repository.
 *
 * Typical flow: `await load()`, then `preprocess_input(imgElement)` and
 * `predict(tensor)`. The JSON config fetched from `model_config_url` supplies
 * the fields read by this class: LABELS, IMAGE_SIZE, IMAGE_SCALE, OP_POINT,
 * SCALE_UPPER and OUTPUT_NODE.
 */
export class ChestXrayModel extends Model {
    constructor() {
        super();
        this.model_name = "Chest X-ray model";
        this.model_weights_url = 'https://raw.githubusercontent.com/mlmed/chester-xray/master/models/xrv-all-45rot15trans15scale/model.json';
        this.model_config_url = 'https://raw.githubusercontent.com/mlmed/chester-xray/master/models/xrv-all-45rot15trans15scale/config.json';
        this.model = null;
        this.model_config = null;
        this.model_ready = false;
        this.model_config_ready = false;
        // Flag exposed to callers; this class does not read it itself.
        this.display_gradients = true;
    }

    /**
     * Download the graph model weights and mark the model as ready.
     * @returns {Promise<object>} the loaded tf.GraphModel
     */
    async load_model() {
        this.model = await tf.loadGraphModel(this.model_weights_url);
        this.model_ready = true;
        return this.model;
    }

    /**
     * Download the model's JSON config and mark it as ready.
     * @returns {Promise<object>} the parsed config
     * @throws {Error} if the HTTP request does not succeed
     */
    async load_config() {
        const response = await fetch(this.model_config_url);
        if (!response.ok) {
            // Fail loudly here instead of letting response.json() choke on an error page.
            throw new Error(`Failed to fetch model config (${response.status} ${response.statusText}): ${this.model_config_url}`);
        }
        this.model_config = await response.json();
        this.model_config_ready = true;
        return this.model_config;
    }

    /**
     * Load weights and config.
     * @returns {Promise<boolean[]>} resolves to [model_ready, model_config_ready]
     */
    async load() {
        // The two downloads are independent, so run them concurrently.
        await Promise.all([this.load_model(), this.load_config()]);
        return [this.model_ready, this.model_config_ready];
    }

    /**
     * @returns {number} 1 when both weights and config are loaded, otherwise 0
     */
    get_status() {
        return this.model_ready && this.model_config_ready ? 1 : 0;
    }

    /**
     * @returns {string[]} the config's LABELS (may contain empty strings)
     */
    get_labels() {
        return this.model_config.LABELS;
    }

    /**
     * Resize an image so its shorter side equals `size`, center-crop it to
     * size x size, average the color channels and scale the result to [0, 1].
     *
     * NOTE: this mutates `imgElement.width`/`height`. The resize happens because
     * tf.browser.fromPixels reads the element at those dimensions.
     *
     * @param {HTMLImageElement} imgElement - image to read (mutated)
     * @param {number} size - output side length in pixels
     * @returns {object} tf.Tensor of shape [size, size]
     */
    prepare_image_resize_crop(imgElement, size) {
        const origWidth = imgElement.width;
        const origHeight = imgElement.height;
        if (origWidth < origHeight) {
            imgElement.width = size;
            imgElement.height = Math.floor(size * origHeight / origWidth);
        } else {
            imgElement.height = size;
            imgElement.width = Math.floor(size * origWidth / origHeight);
        }

        // tidy() disposes the intermediate tensors; only the returned one survives.
        return tf.tidy(() => {
            const img = tf.browser.fromPixels(imgElement).toFloat();
            // fromPixels yields [height, width, channels]; center a size x size window.
            const rowOffset = Math.floor(img.shape[0] / 2 - size / 2);
            const colOffset = Math.floor(img.shape[1] / 2 - size / 2);
            const cropped = img.slice([rowOffset, colOffset], [size, size]);
            return cropped.mean(2).div(255);
        });
    }

    /**
     * Convert an <img> into the model's input tensor.
     * @param {HTMLImageElement} input - image to classify (its size is mutated)
     * @returns {object} tf.Tensor of shape [1, 1, IMAGE_SIZE, IMAGE_SIZE]
     * @throws {Error} if `input` is not an HTMLImageElement
     */
    preprocess_input(input) {
        if (!(input instanceof HTMLImageElement)) {
            throw new Error('Input is not an image element.');
        }
        const imageSize = this.model_config.IMAGE_SIZE;
        const imageScale = this.model_config.IMAGE_SCALE;

        return tf.tidy(() => {
            const grayscale = this.prepare_image_resize_crop(input, imageSize);
            // Map [0, 1] to [-1, 1], then apply the configured scale factor.
            const scaled = grayscale.mul(2).sub(1).mul(tf.scalar(imageScale));
            // Leading [1, 1] is presumably batch and channel (NCHW) — confirm against the model.
            return scaled.reshape([1, 1, imageSize, imageSize]);
        });
    }

    /**
     * Map raw per-class activations to display scores using each class's
     * operating point. The operating point maps to 0.5. Assuming activations
     * lie in [0, 1], values below it land in [0, 0.5) and values at or above
     * it land in [0.5, 1]. Scores above 0.6 are optionally stretched by
     * SCALE_UPPER, capped at 1.
     *
     * @param {ArrayLike<number>} values - one activation per config label
     * @returns {{className: string, probability: number}[]}
     */
    distOverClasses(values) {
        const labels = this.model_config.LABELS;
        const opPoints = this.model_config.OP_POINT;
        const scaleUpper = this.model_config.SCALE_UPPER;

        return Array.from(values, (value, i) => {
            const opPoint = opPoints[i];
            let normalized;
            if (value < opPoint) {
                normalized = value / (opPoint * 2);
            } else {
                normalized = 1 - ((1 - value) / ((1 - opPoint) * 2));
                // Logical && (not bitwise &): bitwise truncated SCALE_UPPER to an int
                // and skipped scaling for values such as 0.9 or 2.
                if (normalized > 0.6 && scaleUpper) {
                    normalized = Math.min(1, normalized * scaleUpper);
                }
            }
            return { className: labels[i], probability: normalized };
        });
    }

    /**
     * Run the model and return normalized per-class scores.
     * @param {object} input - tensor produced by preprocess_input
     * @returns {Promise<Array>} [scores for non-empty labels, []]
     */
    async predict(input) {
        const outputTensor = tf.tidy(() => this.model.execute(input, [this.model_config.OUTPUT_NODE]));
        let activations;
        try {
            activations = await outputTensor.data();
        } finally {
            // .data() copies the values out; release the tensor so repeated calls don't leak memory.
            outputTensor.dispose();
        }

        const classProbabilities = this.distOverClasses(activations);
        // Some labels in the config are empty strings; drop those classes.
        const nonEmpty = classProbabilities.filter(c => c.className !== '');
        return [nonEmpty, []];
    }

    /**
     * Index of the highest-probability entry. Ties go to the first entry, and
     * the index is resolved through className so it points at the first entry
     * carrying that label.
     *
     * NOTE(review): predict() returns a list with empty-label classes removed,
     * while compute_gradients() gathers from the raw model output. If callers
     * feed this index straight into compute_gradients, the two indexings can
     * disagree — confirm at the call site.
     *
     * @param {{className: string, probability: number}[]} predicted_class_probabilities
     * @returns {number} index into the given list, or -1 if it is empty
     */
    get_top_class_index(predicted_class_probabilities) {
        if (predicted_class_probabilities.length === 0) {
            return -1;
        }
        let bestIdx = 0;
        for (let i = 1; i < predicted_class_probabilities.length; i++) {
            if (predicted_class_probabilities[i].probability > predicted_class_probabilities[bestIdx].probability) {
                bestIdx = i;
            }
        }
        const bestName = predicted_class_probabilities[bestIdx].className;
        return predicted_class_probabilities.findIndex(c => c.className === bestName);
    }

    /**
     * Saliency map: gradient of output[idx] with respect to the input. Axis 0
     * is averaged, the absolute value is taken, the next axis is max-reduced,
     * and the result is normalized by its maximum.
     * @param {object} input - tensor produced by preprocess_input
     * @param {number} idx - index into the flattened model output
     * @returns {Promise<object>} tf.Tensor with values in [0, 1]
     */
    async compute_gradients(input, idx) {
        return tf.tidy(() => {
            const gradFn = tf.grad(x => this.model.predict(x).reshape([-1]).gather(idx));
            const grad = gradFn(input);
            const saliency = grad.mean(0).abs().max(0);
            return saliency.div(saliency.max());
        });
    }
}
