/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.labkit.ui.segmentation;

import bdv.export.ProgressWriter;
import bdv.export.ProgressWriterConsole;
import java.util.ArrayList;
import java.util.Objects;
import net.imagej.ImgPlus;
import net.imagej.axis.Axes;
import net.imagej.axis.CalibratedAxis;
import net.imagej.axis.IdentityAxis;
import net.imglib2.Dimensions;
import net.imglib2.Interval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImg;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.NumericType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Intervals;
import org.apache.commons.lang3.ArrayUtils;
import org.scijava.Context;
import sc.fiji.labkit.pixel_classification.utils.SingletonContext;
import sc.fiji.labkit.ui.inputimage.DatasetInputImage;
import sc.fiji.labkit.ui.inputimage.ImgPlusViewsOld;
import sc.fiji.labkit.ui.models.CachedImageFactory;
import sc.fiji.labkit.ui.models.DefaultCachedImageFactory;
import sc.fiji.labkit.ui.segmentation.SegmentationUtils;
import sc.fiji.labkit.ui.segmentation.Segmenter;
import sc.fiji.labkit.ui.segmentation.weka.TrainableSegmentationSegmenter;
import sc.fiji.labkit.ui.utils.ParallelUtils;

/**
 * Applies a {@link Segmenter} to whole images, producing either a label
 * segmentation ({@link #segment(ImgPlus)}) or a per-class probability map
 * ({@link #probabilityMap(ImgPlus)}).
 * <p>
 * A segmenter must be configured first, via the constructor,
 * {@link #setSegmenter(Segmenter)} or {@link #openModel(String)}. Results that
 * exceed 100 million elements are computed into a cached image (created by
 * {@link DefaultCachedImageFactory}) instead of a single {@link ArrayImg}.
 * Progress is reported to the configured {@link ProgressWriter}, which is a
 * {@link ProgressWriterConsole} by default.
 */
public class SegmentationTool {

    /**
     * Element count above which results are computed into a cached image
     * instead of a single in-memory {@link ArrayImg}.
     */
    private static final long CACHE_THRESHOLD_ELEMENTS = 100_000_000L;

    private Segmenter segmenter = null;
    private Context context = null;
    private ProgressWriter progressWriter = new ProgressWriterConsole();
    // null until setUseGpu is called; while null, setSegmenter leaves the
    // segmenter's own GPU setting untouched.
    private Boolean useGpu = null;
    private final CachedImageFactory cachedImageFactory = DefaultCachedImageFactory.getInstance();

    /**
     * Creates a tool without a segmenter; call {@link #setSegmenter(Segmenter)}
     * or {@link #openModel(String)} before segmenting.
     */
    public SegmentationTool() {
    }

    /**
     * Creates a tool that uses the given segmenter.
     *
     * @param segmenter the segmenter to apply
     */
    public SegmentationTool(Segmenter segmenter) {
        this.segmenter = segmenter;
    }

    /**
     * Loads a classifier file into a new {@link TrainableSegmentationSegmenter}
     * and makes it the active segmenter (applying any GPU preference set earlier).
     * The context given to {@link #setContext(Context)} is used, or the shared
     * {@link SingletonContext} if none was set.
     *
     * @param classifierFile path of the classifier to open
     */
    public void openModel(String classifierFile) {
        Context ctx = this.context != null ? this.context : SingletonContext.getInstance();
        TrainableSegmentationSegmenter loaded = new TrainableSegmentationSegmenter(ctx);
        loaded.openModel(classifierFile);
        this.setSegmenter(loaded);
    }

    /**
     * Sets the segmenter to apply. If {@link #setUseGpu(boolean)} was called
     * before, that preference is forwarded to the new segmenter.
     *
     * @param segmenter the segmenter; must not be {@code null}
     * @throws NullPointerException if {@code segmenter} is {@code null}
     */
    public void setSegmenter(Segmenter segmenter) {
        // Reject null up front; previously this only failed when useGpu was set.
        this.segmenter = Objects.requireNonNull(segmenter, "segmenter");
        if (this.useGpu != null) {
            this.segmenter.setUseGpu(this.useGpu);
        }
    }

    /**
     * Sets the SciJava context used by {@link #openModel(String)}.
     *
     * @param context the context; must not be {@code null}
     */
    public void setContext(Context context) {
        this.context = Objects.requireNonNull(context);
    }

    /**
     * Sets where progress of segmentation and probability-map computation is reported.
     *
     * @param progressWriter the progress sink; must not be {@code null}
     */
    public void setProgressWriter(ProgressWriter progressWriter) {
        this.progressWriter = Objects.requireNonNull(progressWriter);
    }

    /**
     * Sets the GPU preference. It is forwarded to the current segmenter, if any,
     * and to every segmenter set afterwards.
     *
     * @param useGpu whether the segmenter should use the GPU
     */
    public void setUseGpu(boolean useGpu) {
        this.useGpu = useGpu;
        if (this.segmenter != null) {
            this.segmenter.setUseGpu(useGpu);
        }
    }

    /**
     * Segments the image into an {@link UnsignedByteType} label image.
     *
     * @param image the input image
     * @return the segmentation, with the input's axes minus any channel axis
     * @throws IllegalStateException if no segmenter has been configured
     */
    public ImgPlus<UnsignedByteType> segment(ImgPlus<?> image) {
        return this.segment(image, new UnsignedByteType());
    }

    /**
     * Segments the image into a label image of the given pixel type.
     *
     * @param image the input image
     * @param type  output pixel type; it is cast to {@link NativeType} on the
     *              non-cached path, so a non-native type fails there
     * @return the segmentation, with the input's axes minus any channel axis
     * @throws IllegalStateException if no segmenter has been configured
     */
    public <T extends IntegerType<?>> ImgPlus<T> segment(ImgPlus<?> image, T type) {
        requireSegmenter();
        ImgPlus<? extends NumericType<?>> imgPlus = new DatasetInputImage(image).imageForSegmentation();
        Img<T> outputImg = this.useCacheForSegmentation(imgPlus)
                ? this.calculateSegmentationOnCachedImg(imgPlus, type)
                : this.calculateSegmentation(imgPlus, type);
        ArrayList<CalibratedAxis> axes = axesWithoutChannel(imgPlus);
        return new ImgPlus<>(outputImg, "segmentation of " + image.getName(), axes.toArray(new CalibratedAxis[0]));
    }

    /**
     * Decides whether the segmentation is computed into a cached image.
     */
    private boolean useCacheForSegmentation(ImgPlus<?> imgPlus) {
        // NOTE(review): this counts all input elements, including any channel
        // axis, while the output has no channel axis; useCacheForProbabilityMap
        // counts output elements instead. Kept as-is: it only selects between
        // the cached and the ArrayImg code path.
        return Intervals.numElements(imgPlus) > CACHE_THRESHOLD_ELEMENTS;
    }

    /**
     * Computes the segmentation cell by cell into an in-memory {@link ArrayImg}.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private <T extends IntegerType<?>> Img<T> calculateSegmentation(ImgPlus<?> imgPlus, T type) {
        Interval outputInterval = SegmentationUtils.intervalNoChannels(imgPlus);
        int[] cellSize = this.segmenter.suggestCellSize(imgPlus);
        // Raw types: T's declared bound lacks NativeType, which ArrayImgFactory
        // needs, so the type is cast (ClassCastException if it is not native).
        ArrayImg outputImg = new ArrayImgFactory((NativeType) type)
                .create(Intervals.dimensionsAsLongArray(outputInterval));
        ParallelUtils.applyOperationOnCells(outputImg, cellSize,
                outputCell -> this.segmenter.segment(imgPlus,
                        (RandomAccessibleInterval<? extends IntegerType<?>>) outputCell),
                this.progressWriter);
        return outputImg;
    }

    /**
     * Computes the segmentation into a cached image and populates it.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private <T extends IntegerType<?>> Img<T> calculateSegmentationOnCachedImg(ImgPlus<?> imgPlus, T type) {
        // Unchecked casts bridge T's weak declared bound to the helper's signature.
        Img outputImg = (Img) Cast.unchecked(SegmentationUtils.createCachedSegmentation(
                this.segmenter, imgPlus, this.cachedImageFactory, (IntegerType) Cast.unchecked(type)));
        ParallelUtils.populateCachedImg(outputImg, this.progressWriter);
        return outputImg;
    }

    /**
     * Computes a {@link FloatType} probability map for the image.
     *
     * @param image the input image
     * @return the probability map; its axes are the input's non-channel axes
     *         followed by a channel axis (one channel per class name on the
     *         non-cached path)
     * @throws IllegalStateException if no segmenter has been configured
     */
    public ImgPlus<FloatType> probabilityMap(ImgPlus<?> image) {
        requireSegmenter();
        ImgPlus<? extends NumericType<?>> imgPlus = new DatasetInputImage(image).imageForSegmentation();
        Img<FloatType> outputImg = this.useCacheForProbabilityMap(imgPlus)
                ? this.calculateOnCachedImg(imgPlus)
                : this.calculateProbabilityMap(imgPlus);
        ArrayList<CalibratedAxis> axes = axesWithoutChannel(imgPlus);
        axes.add(new IdentityAxis(Axes.CHANNEL));
        return new ImgPlus<>(outputImg, "probability map for " + image.getName(), axes.toArray(new CalibratedAxis[0]));
    }

    /**
     * Decides whether the probability map is computed into a cached image,
     * based on its element count (non-channel input elements times number of classes).
     */
    private boolean useCacheForProbabilityMap(ImgPlus<?> image) {
        int numberOfChannels = this.segmenter.classNames().size();
        long outputElements = Intervals.numElements(SegmentationUtils.intervalNoChannels(image)) * numberOfChannels;
        return outputElements > CACHE_THRESHOLD_ELEMENTS;
    }

    /**
     * Computes the probability map into a cached image and populates it.
     */
    private Img<FloatType> calculateOnCachedImg(ImgPlus<?> image) {
        Img<FloatType> outputImg = SegmentationUtils.createCachedProbabilityMap(this.segmenter, image, this.cachedImageFactory);
        ParallelUtils.populateCachedImg(outputImg, this.progressWriter);
        return outputImg;
    }

    /**
     * Computes the probability map cell by cell into an in-memory float
     * {@link ArrayImg} whose last dimension holds one channel per class.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private Img<FloatType> calculateProbabilityMap(ImgPlus<?> imgPlus) {
        int numberOfChannels = this.segmenter.classNames().size();
        // Append the class dimension to both the image size and the cell size,
        // so every cell covers all classes at once.
        long[] imageSize = ArrayUtils.add(
                Intervals.dimensionsAsLongArray(SegmentationUtils.intervalNoChannels(imgPlus)),
                (long) numberOfChannels);
        int[] cellSize = ArrayUtils.add(this.segmenter.suggestCellSize(imgPlus), numberOfChannels);
        ArrayImg outputImg = ArrayImgs.floats(imageSize);
        ParallelUtils.applyOperationOnCells(outputImg, cellSize,
                outputCell -> this.segmenter.predict(imgPlus,
                        (RandomAccessibleInterval<? extends RealType<?>>) outputCell),
                this.progressWriter);
        return outputImg;
    }

    /**
     * Copies the calibrated axes of {@code imgPlus}, dropping any channel axis.
     */
    private static ArrayList<CalibratedAxis> axesWithoutChannel(ImgPlus<? extends NumericType<?>> imgPlus) {
        ArrayList<CalibratedAxis> axes = new ArrayList<>(ImgPlusViewsOld.getCalibratedAxes(imgPlus));
        axes.removeIf(axis -> axis.type() == Axes.CHANNEL);
        return axes;
    }

    /**
     * Returns the configured segmenter, failing fast with a descriptive message.
     *
     * @throws IllegalStateException if no segmenter has been set
     */
    private Segmenter requireSegmenter() {
        if (this.segmenter == null) {
            throw new IllegalStateException(
                    "No segmenter configured: call setSegmenter(...) or openModel(...) first");
        }
        return this.segmenter;
    }
}

