/*=======================================================================
 ***         THE CONTENT OF THIS WORK IS PROPRIETARY TO FEI S.A.S,                  ***
 ***                   A PART OF THERMO FISHER SCIENTIFIC,                          ***
 ***              AND IS DISTRIBUTED UNDER A LICENSE AGREEMENT.                     ***
 ***                                                                                ***
 ***  REPRODUCTION, DISCLOSURE,  OR USE,  IN WHOLE OR IN PART,  OTHER THAN AS       ***
 ***  SPECIFIED  IN THE LICENSE ARE  NOT TO BE  UNDERTAKEN  EXCEPT WITH PRIOR       ***
 ***  WRITTEN AUTHORIZATION OF FEI S.A.S, A PART OF THERMO FISHER SCIENTIFIC.       ***
 ***                                                                                ***
 ***                        RESTRICTED RIGHTS LEGEND                                ***
 ***  USE, DUPLICATION, OR DISCLOSURE BY THE GOVERNMENT OF THE CONTENT OF THIS      ***
 ***  WORK OR RELATED DOCUMENTATION IS SUBJECT TO RESTRICTIONS AS SET FORTH IN      ***
 ***  SUBPARAGRAPH (C)(1) OF THE COMMERCIAL COMPUTER SOFTWARE RESTRICTED RIGHT      ***
 ***  CLAUSE  AT FAR 52.227-19  OR SUBPARAGRAPH  (C)(1)(II)  OF  THE RIGHTS IN      ***
 ***  TECHNICAL DATA AND COMPUTER SOFTWARE CLAUSE AT DFARS 52.227-7013.             ***
 ***                                                                                ***
 ***   COPYRIGHT (C) 2021-2025 BY FEI S.A.S, A PART OF THERMO FISHER SCIENTIFIC,    ***
 ***                       BORDEAUX, FRANCE                                         ***
 ***                      ALL RIGHTS RESERVED                                       ***
 **=======================================================================*/
#pragma once

#include <ImageDev/Processing/GenericAlgorithm.h>
#include <ImageDev/ImageDevCppExports.h>
#include <ImageDev/Exception.h>
#include <iolink/Vector.h>
#include <iolink/view/ImageView.h>
#include <ImageDev/Data/Model/OnnxModel.h>

namespace imagedev
{
/// Computes a prediction on a two-dimensional image from an ONNX model and applies a post-processing step to generate a label or a binary image.
///
/// Usage: set the parameters through the setters below, then call execute() and read the result with outputObjectImage().
/// The free function onnxPredictionSegmentation2d(), declared in this header, offers the same processing as a single call.
class IMAGEDEV_CPP_API OnnxPredictionSegmentation2d final : public GenericAlgorithm
{
public:
    /// The tensor layout expected as input by the model. The input image is automatically converted to this layout by the algorithm.
    enum DataFormat
    {
    /// The layout is organized with interlaced channels. For instance, if the input is a color image, each pixel presents its RGB components successively.
        NHWC = 0,
    /// The layout is organized with separated channels. Each channel is an individual plane.
        NCHW
    };
    /// The type of normalization to apply before computing the prediction. It is recommended to apply the same pre-processing as during the training.
    enum InputNormalizationType
    {
    /// No normalization is applied before executing the prediction.
        NONE = 0,
    /// A normalization is applied by subtracting the minimum and dividing by the data range.
        MIN_MAX,
    /// A normalization is applied by subtracting the mean and dividing by the standard deviation.
        STANDARDIZATION
    };
    /// The scope for computing normalization (mean, standard deviation, minimum or maximum). This parameter is ignored if the normalization type is set to NONE.
    enum NormalizationScope
    {
    /// The normalization is applied globally on the input batch.
        GLOBAL = 0,
    /// The normalization is applied individually on each image of the input batch.
        PER_SLICE
    };

    /// Constructs the command.
    /// NOTE(review): the initial parameter values are set by the implementation, which is not visible in this header. Check the library documentation for the defaults.
    OnnxPredictionSegmentation2d();


    /// Gets the inputImage parameter.
    /// The input image. It can be a grayscale or color image, depending on the selected model.
    std::shared_ptr< iolink::ImageView > inputImage() const;
    /// Sets the inputImage parameter.
    /// The input image. It can be a grayscale or color image, depending on the selected model.
    void setInputImage( std::shared_ptr< iolink::ImageView > inputImage );

    /// Gets the inputOnnxModel parameter.
    /// The in-memory ONNX model.
    OnnxModel::Ptr inputOnnxModel() const;
    /// Sets the inputOnnxModel parameter.
    /// The in-memory ONNX model.
    void setInputOnnxModel( const OnnxModel::Ptr& inputOnnxModel );

    /// Gets the dataFormat parameter.
    /// The tensor layout expected as input by the model. The input image is automatically converted to this layout by the algorithm.
    OnnxPredictionSegmentation2d::DataFormat dataFormat() const;
    /// Sets the dataFormat parameter.
    /// The tensor layout expected as input by the model. The input image is automatically converted to this layout by the algorithm.
    void setDataFormat( const OnnxPredictionSegmentation2d::DataFormat& dataFormat );

    /// Gets the inputNormalizationType parameter.
    /// The type of normalization to apply before computing the prediction. It is recommended to apply the same pre-processing as during the training.
    OnnxPredictionSegmentation2d::InputNormalizationType inputNormalizationType() const;
    /// Sets the inputNormalizationType parameter.
    /// The type of normalization to apply before computing the prediction. It is recommended to apply the same pre-processing as during the training.
    void setInputNormalizationType( const OnnxPredictionSegmentation2d::InputNormalizationType& inputNormalizationType );

    /// Gets the normalizationRange parameter.
    /// The data range in which the input image is normalized before computing the prediction. It is recommended to apply the same pre-processing as during the training. This parameter is ignored if the normalization type is set to NONE.
    iolink::Vector2d normalizationRange() const;
    /// Sets the normalizationRange parameter.
    /// The data range in which the input image is normalized before computing the prediction. It is recommended to apply the same pre-processing as during the training. This parameter is ignored if the normalization type is set to NONE.
    void setNormalizationRange( const iolink::Vector2d& normalizationRange );

    /// Gets the normalizationScope parameter.
    /// The scope for computing normalization (mean, standard deviation, minimum or maximum). This parameter is ignored if the normalization type is set to NONE.
    OnnxPredictionSegmentation2d::NormalizationScope normalizationScope() const;
    /// Sets the normalizationScope parameter.
    /// The scope for computing normalization (mean, standard deviation, minimum or maximum). This parameter is ignored if the normalization type is set to NONE.
    void setNormalizationScope( const OnnxPredictionSegmentation2d::NormalizationScope& normalizationScope );

    /// Gets the tileSize parameter.
    /// The width and height in pixels of the sliding window. This size includes the user-defined tile overlap. Each dimension must be a multiple of 2^k, where k is the number of downsampling or upsampling layers of the model.
    iolink::Vector2u32 tileSize() const;
    /// Sets the tileSize parameter.
    /// The width and height in pixels of the sliding window. This size includes the user-defined tile overlap. Each dimension must be a multiple of 2^k, where k is the number of downsampling or upsampling layers of the model.
    void setTileSize( const iolink::Vector2u32& tileSize );

    /// Gets the tileOverlap parameter.
    /// The number of pixels used as overlap between the tiles. An overlap of zero may lead to artifacts in the prediction result. A non-zero overlap reduces such artifacts but increases the computation time.
    uint32_t tileOverlap() const;
    /// Sets the tileOverlap parameter.
    /// The number of pixels used as overlap between the tiles. An overlap of zero may lead to artifacts in the prediction result. A non-zero overlap reduces such artifacts but increases the computation time.
    void setTileOverlap( const uint32_t& tileOverlap );

    /// Gets the outputObjectImage parameter.
    /// The output image. Its dimensions and calibration are forced to the same values as the input. Its interpretation is binary if the model produces one channel, label otherwise.
    std::shared_ptr< iolink::ImageView > outputObjectImage() const;
    /// Sets the outputObjectImage parameter.
    /// The output image. Its dimensions and calibration are forced to the same values as the input. Its interpretation is binary if the model produces one channel, label otherwise.
    void setOutputObjectImage( std::shared_ptr< iolink::ImageView > outputObjectImage );

    /// Launches the command with the parameters currently set.
    /// The result is then available through outputObjectImage().
    /// NOTE(review): this header includes ImageDev/Exception.h, which suggests that failures are reported by throwing.
    /// Confirm the exact exception types in the library documentation.
    void execute();

};

/// Computes a prediction on a two-dimensional image from an ONNX model and applies a post-processing step to generate a label or a binary image.
/// This is a one-call convenience wrapper around the OnnxPredictionSegmentation2d command class.
/// @param inputImage The input image. It can be a grayscale or color image, depending on the selected model.
/// @param inputOnnxModel The in-memory ONNX model.
/// @param dataFormat The tensor layout expected as input by the model. The input image is automatically converted to this layout by the algorithm.
/// @param inputNormalizationType The type of normalization to apply before computing the prediction. It is recommended to apply the same pre-processing as during the training.
/// @param normalizationRange The data range in which the input image is normalized before computing the prediction. It is recommended to apply the same pre-processing as during the training. This parameter is ignored if the normalization type is set to NONE.
/// @param normalizationScope The scope for computing normalization (mean, standard deviation, minimum or maximum). This parameter is ignored if the normalization type is set to NONE.
/// @param tileSize The width and height in pixels of the sliding window. This size includes the user-defined tile overlap. Each dimension must be a multiple of 2^k, where k is the number of downsampling or upsampling layers of the model.
/// @param tileOverlap The number of pixels used as overlap between the tiles. An overlap of zero may lead to artifacts in the prediction result. A non-zero overlap reduces such artifacts but increases the computation time.
/// @param outputObjectImage The output image. Its dimensions and calibration are forced to the same values as the input. Its interpretation is binary if the model produces one channel, label otherwise.
///        Optional; defaults to nullptr. NOTE(review): presumably the output image is allocated by the algorithm when nullptr is passed. Confirm against the implementation.
/// @return Returns the outputObjectImage output parameter.
IMAGEDEV_CPP_API 
std::shared_ptr< iolink::ImageView >
onnxPredictionSegmentation2d( std::shared_ptr< iolink::ImageView > inputImage,
                              OnnxModel::Ptr inputOnnxModel,
                              OnnxPredictionSegmentation2d::DataFormat dataFormat,
                              OnnxPredictionSegmentation2d::InputNormalizationType inputNormalizationType,
                              const iolink::Vector2d& normalizationRange,
                              OnnxPredictionSegmentation2d::NormalizationScope normalizationScope,
                              const iolink::Vector2u32& tileSize,
                              uint32_t tileOverlap,
                              std::shared_ptr< iolink::ImageView > outputObjectImage = nullptr );
} // namespace imagedev
