import {Model} from "./model";
import * as ort from 'onnxruntime-web';
import {Tensor} from 'onnxruntime-web';
import * as tf from "@tensorflow/tfjs";


/**
 * Pediatric wrist X-ray detection model, executed in the browser with onnxruntime-web.
 *
 * Typical flow: `load()` downloads the ONNX weights and creates the inference
 * session, `preprocess_input()` turns an <img> element into an input tensor, and
 * `predict()` returns per-class maximum probabilities plus the raw detections.
 */
export class PedWristXrayModel extends Model {
    constructor() {
        super();
        this.model_name = "Pediatric wrist X-ray model";
        // TODO: this should be replaced by a static file server.
        // Upstream weights: https://github.com/mdciri/YOLOv7-Bone-Fracture-Detection/releases/download/trained-models/yolov7-p6-bonefracture.onnx
        this.model_weights_url = 'https://api.kassandra.julianklug.com/pediatric_wrist_xray';
        this.model_config_url = '';
        this.model = null;
        this.model_config = {
            // Input tensor shape: [batch, channels, height, width].
            'input_size': [1, 3, 640, 640],
            // labels[i] is the name for class index i in the model output (see predict()).
            'labels': ["Bone Anomaly", "Bone Lesion", "Foreign Body", "Fracture", "Metal", "Periosteal Reaction",
                "Pronator Sign", "Soft Tissue", "Text"],
            // NOTE(review): not applied anywhere in this class; predict() returns all detections unfiltered.
            'confidence_threshold': 0.3
        };
        this.model_ready = false;
        // The config is defined inline above, so it is available immediately.
        this.model_config_ready = true;
        this.display_gradients = false;
    }

    /**
     * Download the ONNX weights and create an inference session.
     *
     * @returns {Promise<ort.InferenceSession>} the session, also stored in `this.model`.
     * @throws {Error} if the weights request returns a non-2xx HTTP status.
     */
    async load_model() {
        const response = await fetch(this.model_weights_url);
        // fetch() only rejects on network failure; without this check an HTTP error
        // page would be handed to onnxruntime as if it were model bytes.
        if (!response.ok) {
            throw new Error(
                `Failed to download model weights from ${this.model_weights_url}: ` +
                `${response.status} ${response.statusText}`
            );
        }
        const modelBuffer = await response.arrayBuffer();

        // Default execution providers are used. An explicit choice could be passed as
        // session options, e.g. { executionProviders: ['wasm'] }.
        this.model = await ort.InferenceSession.create(modelBuffer);
        this.model_ready = true;
        return this.model;
    }

    /**
     * Load everything this model needs (currently only the weights).
     *
     * @returns {Promise<boolean[]>} resolves to `[model_ready, model_config_ready]`.
     */
    async load() {
        await this.load_model();
        return [this.model_ready, this.model_config_ready];
    }

    /**
     * @returns {number} 1 when both weights and config are ready, otherwise 0.
     */
    get_status() {
        return this.model_ready && this.model_config_ready ? 1 : 0;
    }

    /**
     * @returns {string[]} class names, indexed by the model's class index.
     */
    get_labels() {
        return this.model_config.labels;
    }

    /**
     * Convert an image element into the model's input tensor.
     *
     * @param {HTMLImageElement} input - image to run the model on.
     * @returns {Promise<Tensor>} float32 tensor of shape `model_config.input_size`.
     * @throws {Error} if `input` is not an HTMLImageElement.
     */
    async preprocess_input(input) {
        if (!(input instanceof HTMLImageElement)) {
            throw new Error('Input is not an image element.');
        }
        return await this.getImageTensorFromImageElement(input, this.model_config.input_size);
    }

    // Image preprocessing helpers adapted from
    // https://onnxruntime.ai/docs/tutorials/web/classify-images-nextjs-github-template.html

    /**
     * Read an image's pixels and resize them bilinearly.
     *
     * @param img - any pixel source accepted by `tf.browser.fromPixels`.
     * @param {number} height - target height in pixels.
     * @param {number} width - target width in pixels.
     * @returns {Promise<Float32Array>} interleaved RGB values (0-255), laid out [height, width, 3].
     */
    async loadImage(img, height = 640, width = 640) {
        // tidy() disposes the intermediate fromPixels/toFloat tensors; only the resized
        // tensor escapes, and it is released explicitly once its data has been read.
        const resized = tf.tidy(() =>
            tf.image.resizeBilinear(tf.browser.fromPixels(img).toFloat(), [height, width])
        );
        try {
            return await resized.data();
        } finally {
            resized.dispose();
        }
    }

    /**
     * Convert interleaved pixel data into a normalized, channel-planar ORT tensor.
     *
     * Performs HWC -> CHW transposition and scales values by 1/255, i.e.
     * img = (img - mean) * scale with mean = 0 and scale = 1/255.
     *
     * @param {ArrayLike<number>} image - interleaved pixel values laid out [H, W, C].
     * @param {number[]} dims - output shape [1, C, H, W].
     * @returns {Tensor} float32 tensor with shape `dims`.
     * @throws {Error} if `image.length` does not equal C * H * W.
     */
    imageDataToTensor(image, dims) {
        const [, channels, height, width] = dims;
        const planeSize = height * width;
        if (image.length !== channels * planeSize) {
            throw new Error(
                `Expected ${channels * planeSize} pixel values for dims [${dims}], got ${image.length}.`
            );
        }

        const float32Data = new Float32Array(channels * planeSize);
        for (let pixel = 0; pixel < planeSize; pixel++) {
            for (let channel = 0; channel < channels; channel++) {
                float32Data[channel * planeSize + pixel] = image[pixel * channels + channel] / 255.0;
            }
        }
        return new Tensor("float32", float32Data, dims);
    }

    /**
     * Resize an image element and convert it into an input tensor.
     *
     * @param {HTMLImageElement} img - source image.
     * @param {number[]} dims - tensor shape [1, C, H, W].
     * @returns {Promise<Tensor>} float32 tensor with shape `dims`.
     */
    async getImageTensorFromImageElement(img, dims = [1, 3, 640, 640]) {
        const imageData = await this.loadImage(img, dims[2], dims[3]);
        return this.imageDataToTensor(imageData, dims);
    }

    /**
     * Run the model and summarise its detections.
     *
     * Each output row is read as: elements 0-3 -> bounding box, element 4 -> confidence,
     * element 5 -> class index.
     * NOTE(review): this row layout is assumed, not verified here. YOLOv7 "end2end" ONNX
     * exports emit [batch_id, x0, y0, x1, y1, class_id, score]; confirm against the served model.
     *
     * @param {Tensor} input - tensor produced by `preprocess_input`.
     * @returns {Promise<[Array<{className: string, probability: number}>,
     *                    Array<{className: string, probability: number, boundingBox: ArrayLike<number>}>]>}
     *          per-class maximum probability (0 when a class was not detected), and all detections.
     */
    async predict(input) {
        const feeds = {[this.model.inputNames[0]]: input};
        const results = await this.model.run(feeds);
        const output = results[this.model.outputNames[0]];
        const n_detections = output.dims[0];
        const n_attributes = output.dims[1];
        const result_data = output.data;
        const labels = this.model_config.labels;

        const classProbabilities = labels.map((label) => ({className: label, probability: 0}));
        const detections = [];
        for (let i = 0; i < n_detections; i++) {
            const detection = result_data.slice(i * n_attributes, (i + 1) * n_attributes);
            const probability = detection[4];
            // The class index may be stored as a float in the output buffer; round before indexing.
            const label_idx = Math.round(detection[5]);
            // Written so that NaN also fails the range check.
            if (!(label_idx >= 0 && label_idx < labels.length)) {
                console.warn(`Skipping detection ${i} with unknown class index ${detection[5]}.`);
                continue;
            }

            detections.push({
                className: labels[label_idx],
                probability: probability,
                boundingBox: detection.slice(0, 4),
            });

            classProbabilities[label_idx].probability = Math.max(classProbabilities[label_idx].probability, probability);
        }

        return [classProbabilities, detections];
    }

    /**
     * @param {Array<{probability: number}>} predicted_class_probabilities - first element returned by predict().
     * @returns {number} index of the first entry with the highest probability, or -1 if the list is empty.
     */
    get_top_class_index(predicted_class_probabilities) {
        const probabilities = predicted_class_probabilities.map((c) => c.probability);
        return probabilities.indexOf(Math.max(...probabilities));
    }

    /**
     * Gradient computation is not implemented for this model.
     *
     * @returns {Promise<undefined>}
     */
    async compute_gradients(input, idx) {
        return undefined;
    }

}