import * as tf_conv from '@tensorflow/tfjs-converter';
import * as tfc from '@tensorflow/tfjs-core';
import * as tfb_cpu from '@tensorflow/tfjs-backend-cpu'; // As backup
//import * as tfb_wasm from '@tensorflow/tfjs-backend-wasm';
import * as tfb_webgl from '@tensorflow/tfjs-backend-webgl';

/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// As backup
//import * as tfb_wasm from '@tensorflow/tfjs-backend-wasm';
// Location of the converted graph model handed to tf_conv.loadGraphModel().
const MOBILENET_MODEL_PATH = '../js/kalaaju/fish_model/model.json';
// Input image size; the model is fed a [1, IMAGE_HEIGHT, IMAGE_WIDTH, 3] batch.
const IMAGE_WIDTH = 480;
const IMAGE_HEIGHT = 320;
const PREDICTION_THRESHOLD = 0.2; // minimum score a class needs to be included in the returned predictions

const TOPK_PREDICTIONS = 4; // maximum number of predictions returned to the caller

const THRESHOLD_DEBUG = 0.0; // minimum score a class needs to be written to the debug log

const TOPK_DEBUG = 5; // maximum number of classes written to the debug log

// Label table: maps an index into the model's output vector to the class name
// reported to the caller. The Classifier class looks labels up by that index,
// so keys must stay aligned with the order of the model's outputs.
// NOTE(review): the labels look like Estonian fish names in the sorted order
// used at training time — confirm against the training pipeline before
// renaming or reordering any entry.
const CLASSES = {
    0: 'Ahven',
    1: 'Angerjas',
    2: 'Atlandi-pelamiid',
    3: 'Emakala',
    4: 'Euroopa-ansoovis',
    5: 'Euroopa-mõrukas',
    6: 'Forell',
    7: 'Harilik-huntahven',
    8: 'Harilik-koger',
    9: 'Harilik-makrell',
    10: 'Harilik-pakslaup',
    11: 'Harjus',
    12: 'Haug',
    13: 'Hink',
    14: 'Hõbekoger',
    15: 'Jõesilm',
    16: 'Kammeljas',
    17: 'Karpkala',
    18: 'Kaugida-unimudil',
    19: 'Kiisk',
    20: 'Kilu',
    21: 'Kirju-pakslaup',
    22: 'Kirjumudil',
    23: 'Kitsashuul-hallkefaal',
    24: 'Koha',
    25: 'Latikas',
    26: 'Lepamaim',
    27: 'Lest',
    28: 'Linask',
    29: 'Luts',
    30: 'Luukarits',
    31: 'Lõhe',
    32: 'Madunõel',
    33: 'Merihärg',
    34: 'Merilest',
    35: 'Merinõel',
    36: 'Meripühvel',
    37: 'Merivarblane',
    38: 'Mudamaim',
    39: 'Must-mudil',
    40: 'Muu-kala',
    41: 'Mõõkkala',
    42: 'Neljapoiseluts',
    43: 'Nolgus',
    44: 'Nugakala',
    45: 'Nurg',
    46: 'Ogalik',
    47: 'Pakshuul-hallkefaal',
    48: 'Pisimudilake',
    49: 'Pullukala',
    50: 'Raudkiisk',
    51: 'Roosärg',
    52: 'Räim',
    53: 'Rääbis',
    54: 'Rünt',
    55: 'Siig',
    56: 'Soomuslest',
    57: 'Suttlimusk',
    58: 'Säga',
    59: 'Säinas',
    60: 'Särg',
    61: 'Teib',
    62: 'Tint',
    63: 'Tippviidikas',
    64: 'Tobias',
    65: 'Trulling',
    66: 'Turb',
    67: 'Tursk',
    68: 'Tuulehaug',
    69: 'Tõugjas',
    70: 'Valgeamuur',
    71: 'Viidikas',
    72: 'Vikerforell',
    73: 'Vimb',
    74: 'Vingerjas',
    75: 'Vinträim',
    76: 'Väike-mudilake',
    77: 'Võikala',
    78: 'Võldas',
    79: 'other',
    80: 'Ümarmudil'
};

/**
 * Image classifier backed by a TensorFlow.js graph model.
 *
 * Constructing an instance immediately starts loading the model; `predict`
 * waits for that load when the model is not yet available.
 */
class Classifier {
    constructor() {
        // Set by mobilenetDemo() once the graph model has been fetched.
        this.mobilenet = null;
        // Kept so that predict() can wait for an in-flight load.
        this.load_promise = this.mobilenetDemo();
    }

    /**
     * @returns {number[]} `[width, height]` in pixels that `predict` expects.
     */
    getExpectedSize() {
        return [IMAGE_WIDTH, IMAGE_HEIGHT];
    }

    /**
     * Loads the graph model and runs one warm-up inference.
     *
     * @returns {Promise<void>} Resolves once the model is loaded and warmed up.
     */
    async mobilenetDemo() {
        this.mobilenet = await tf_conv.loadGraphModel(MOBILENET_MODEL_PATH);

        // Warm up the model. This isn't necessary, but makes the first real
        // prediction faster. Running inside tidy() releases both the dummy
        // input and the output of `predict`.
        tfc.tidy(() => {
            this.mobilenet.predict(tfc.zeros([1, IMAGE_HEIGHT, IMAGE_WIDTH, 3]));
        });
    }

    /**
     * Classifies one image.
     *
     * @param {*} imgData - Pixel source handed to `tfc.browser.fromPixels`.
     *     It must hold exactly IMAGE_WIDTH x IMAGE_HEIGHT pixels, because the
     *     result is reshaped to `[1, IMAGE_HEIGHT, IMAGE_WIDTH, 3]`.
     * @returns {Promise<Array<{className: string, probability: number}>>}
     *     At most TOPK_PREDICTIONS classes scoring at least
     *     PREDICTION_THRESHOLD, best first.
     */
    async predict(imgData) {
        if (this.mobilenet == null) {
            await this.load_promise;
        }

        const logits = tfc.tidy(() => {
            const img = tfc.browser.fromPixels(imgData);
            // No rescaling is applied: the pixel values are only cast to
            // float32 and given a batch dimension.
            // NOTE(review): presumably the model normalizes its own input — confirm.
            const batched = tfc.cast(tfc.reshape(img, [1, IMAGE_HEIGHT, IMAGE_WIDTH, 3]), "float32");
            return this.mobilenet.predict(batched);
        });

        let values;
        try {
            // Download the scores once; both rankings below share them.
            values = await logits.data();
        } finally {
            // tidy() does not dispose the tensor it returns, so release it
            // here, even when the download fails.
            logits.dispose();
        }

        // Write debug output to log.
        console.log("Network results:");
        Classifier.rankTopK(values, TOPK_DEBUG, THRESHOLD_DEBUG).forEach((c) => {
            console.log(`${c.className}: ${c.probability}%`);
        });

        return Classifier.rankTopK(values, TOPK_PREDICTIONS, PREDICTION_THRESHOLD);
    }

    /**
     * Reads the scores out of a model output and ranks them.
     *
     * @param {{data: function(): Promise<ArrayLike<number>>}} logits - Model
     *     output; only its `data()` method is used. It is not disposed here.
     * @param {number} topK - Maximum number of classes to return.
     * @param {number} thresh - Minimum score a class needs to be returned.
     * @returns {Promise<Array<{className: string, probability: number}>>}
     *     Ranked classes, best first.
     */
    async getTopKClasses(logits, topK, thresh) {
        const values = await logits.data();
        return Classifier.rankTopK(values, topK, thresh);
    }

    /**
     * Picks the best-scoring classes from a flat list of scores.
     *
     * @param {ArrayLike<number>} values - One score per class, indexed like CLASSES.
     * @param {number} topK - Maximum number of classes to return. Values larger
     *     than `values.length` are clamped instead of reading past the end.
     * @param {number} thresh - Minimum score a class needs to be returned.
     * @returns {Array<{className: string, probability: number}>} Best first;
     *     `probability` is the score times 100, rounded to an integer.
     */
    static rankTopK(values, topK, thresh) {
        const ranked = Array.from(values, (value, index) => ({ value, index }));
        ranked.sort((a, b) => b.value - a.value);

        const limit = Math.min(topK, ranked.length);
        const picked = [];

        for (let i = 0; i < limit; i++) {
            const { value, index } = ranked[i];
            // Scores are sorted descending, so nothing after the first
            // below-threshold entry can qualify.
            if (value < thresh) {
                break;
            }
            picked.push({
                className: CLASSES[index],
                probability: Math.round(100 * value)
            });
        }

        return picked;
    }
}

export default Classifier;