/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.evaluation;

import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import org.tribuo.DataSource;
import org.tribuo.Example;
import org.tribuo.Output;
import org.tribuo.datasource.ListDataSource;
import org.tribuo.provenance.DataSourceProvenance;

/**
 * Splits a {@link DataSource} into a training {@link DataSource} and a testing {@link DataSource}.
 * <p>
 * All examples are read from the supplied source into memory. They are shuffled with a
 * {@link Random} seeded by {@code seed}. The first {@code (int) (trainProportion * size)}
 * shuffled examples form the training split and the remainder form the test split.
 *
 * @param <T> The output type of the examples.
 */
public class TrainTestSplitter<T extends Output<T>> {
    private final DataSource<T> train;
    private final DataSource<T> test;
    private final DataSourceProvenance originalProvenance;
    private final long seed;
    private final double trainProportion;
    private final int size;

    /**
     * Creates a splitter that puts 70% of the examples in the training split, using seed {@code 1}.
     *
     * @param data The data source to split.
     */
    public TrainTestSplitter(DataSource<T> data) {
        this(data, 1L);
    }

    /**
     * Creates a splitter that puts 70% of the examples in the training split, using the supplied seed.
     *
     * @param data The data source to split.
     * @param seed The seed for the shuffling RNG.
     */
    public TrainTestSplitter(DataSource<T> data, long seed) {
        this(data, 0.7, seed);
    }

    /**
     * Creates a splitter with the supplied training proportion and seed.
     *
     * @param data            The data source to split.
     * @param trainProportion The fraction of examples placed in the training split, in [0, 1].
     * @param seed            The seed for the shuffling RNG.
     * @throws NullPointerException     If {@code data} is null.
     * @throws IllegalArgumentException If {@code trainProportion} is outside [0, 1] or is NaN.
     */
    public TrainTestSplitter(DataSource<T> data, double trainProportion, long seed) {
        Objects.requireNonNull(data, "data");
        // Written as a negated range check so that NaN is rejected too.
        if (!(trainProportion >= 0.0 && trainProportion <= 1.0)) {
            throw new IllegalArgumentException(
                    "trainProportion must be in the range [0, 1], found " + trainProportion);
        }
        this.seed = seed;
        this.trainProportion = trainProportion;
        this.originalProvenance = (DataSourceProvenance) data.getProvenance();

        List<Example<T>> examples = new ArrayList<>();
        for (Example<T> example : data) {
            examples.add(example);
        }
        this.size = examples.size();

        Collections.shuffle(examples, new Random(seed));
        int numTrain = (int) (trainProportion * examples.size());

        // Both splits are subList views of the shuffled list, which is not modified after this point.
        // The provenance objects read seed/trainProportion/size/originalProvenance from `this`.
        // Those fields are all assigned above.
        this.train = new ListDataSource<>(examples.subList(0, numTrain), data.getOutputFactory(),
                new SplitDataSourceProvenance(this, true));
        this.test = new ListDataSource<>(examples.subList(numTrain, examples.size()), data.getOutputFactory(),
                new SplitDataSourceProvenance(this, false));
    }

    /**
     * Returns the number of examples read from the original data source before splitting.
     *
     * @return The total number of examples across both splits.
     */
    public int totalSize() {
        return this.size;
    }

    /**
     * Returns the training split.
     *
     * @return The training data source.
     */
    public DataSource<T> getTrain() {
        return this.train;
    }

    /**
     * Returns the test split.
     *
     * @return The test data source.
     */
    public DataSource<T> getTest() {
        return this.test;
    }

    /**
     * Provenance for one side (train or test) of a {@link TrainTestSplitter} split.
     * <p>
     * Records the following:
     * <ul>
     *   <li>the splitter's class name;</li>
     *   <li>the provenance of the original source;</li>
     *   <li>the training proportion and seed;</li>
     *   <li>the total example count;</li>
     *   <li>whether this is the training side.</li>
     * </ul>
     */
    public static class SplitDataSourceProvenance implements DataSourceProvenance {
        private static final long serialVersionUID = 1L;

        private static final String SOURCE = "source";
        private static final String TRAIN_PROPORTION = "train-proportion";
        private static final String SEED = "seed";
        private static final String SIZE = "size";
        private static final String IS_TRAIN = "is-train";

        private final StringProvenance className;
        private final DataSourceProvenance innerSourceProvenance;
        private final DoubleProvenance trainProportion;
        private final LongProvenance seed;
        private final IntProvenance size;
        private final BooleanProvenance isTrain;

        /**
         * Builds the provenance from the state of the supplied splitter.
         *
         * @param host    The splitter whose configuration is recorded.
         * @param isTrain True if this provenance describes the training split.
         * @param <T>     The output type.
         */
        <T extends Output<T>> SplitDataSourceProvenance(TrainTestSplitter<T> host, boolean isTrain) {
            this.className = new StringProvenance("class-name", host.getClass().getName());
            this.innerSourceProvenance = host.originalProvenance;
            this.trainProportion = new DoubleProvenance(TRAIN_PROPORTION, host.trainProportion);
            this.seed = new LongProvenance(SEED, host.seed);
            this.size = new IntProvenance(SIZE, host.size);
            this.isTrain = new BooleanProvenance(IS_TRAIN, isTrain);
        }

        /**
         * Reconstructs the provenance from a map.
         * <p>
         * The map is keyed by the same names that {@link #iterator()} emits.
         * NOTE(review): presumably {@code checkAndExtractProvenance} throws when a key is missing
         * or has the wrong type. Confirm against the OLCUT docs.
         *
         * @param map The provenance map.
         */
        public SplitDataSourceProvenance(Map<String, Provenance> map) {
            String provClassName = SplitDataSourceProvenance.class.getSimpleName();
            this.className = (StringProvenance) ObjectProvenance.checkAndExtractProvenance(map, "class-name", StringProvenance.class, provClassName);
            this.innerSourceProvenance = (DataSourceProvenance) ObjectProvenance.checkAndExtractProvenance(map, SOURCE, DataSourceProvenance.class, provClassName);
            this.trainProportion = (DoubleProvenance) ObjectProvenance.checkAndExtractProvenance(map, TRAIN_PROPORTION, DoubleProvenance.class, provClassName);
            this.seed = (LongProvenance) ObjectProvenance.checkAndExtractProvenance(map, SEED, LongProvenance.class, provClassName);
            this.size = (IntProvenance) ObjectProvenance.checkAndExtractProvenance(map, SIZE, IntProvenance.class, provClassName);
            this.isTrain = (BooleanProvenance) ObjectProvenance.checkAndExtractProvenance(map, IS_TRAIN, BooleanProvenance.class, provClassName);
        }

        @Override
        public String getClassName() {
            return this.className.getValue();
        }

        /**
         * Returns an iterator over the recorded (key, provenance) pairs, in a fixed order.
         *
         * @return An iterator over a freshly built list of pairs.
         */
        @Override
        public Iterator<Pair<String, Provenance>> iterator() {
            List<Pair<String, Provenance>> list = new ArrayList<>();
            list.add(new Pair<>("class-name", this.className));
            list.add(new Pair<>(SOURCE, this.innerSourceProvenance));
            list.add(new Pair<>(TRAIN_PROPORTION, this.trainProportion));
            list.add(new Pair<>(SEED, this.seed));
            list.add(new Pair<>(SIZE, this.size));
            list.add(new Pair<>(IS_TRAIN, this.isTrain));
            return list.iterator();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof SplitDataSourceProvenance)) {
                return false;
            }
            SplitDataSourceProvenance other = (SplitDataSourceProvenance) o;
            return this.className.equals(other.className)
                    && this.innerSourceProvenance.equals(other.innerSourceProvenance)
                    && this.trainProportion.equals(other.trainProportion)
                    && this.seed.equals(other.seed)
                    && this.size.equals(other.size)
                    && this.isTrain.equals(other.isTrain);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.className, this.innerSourceProvenance, this.trainProportion, this.seed, this.size, this.isTrain);
        }

        @Override
        public String toString() {
            return "SplitDataSourceProvenance(className=" + this.className + ",innerSourceProvenance=" + this.innerSourceProvenance + ",trainProportion=" + this.trainProportion + ",seed=" + this.seed + ",size=" + this.size + ",isTrain=" + this.isTrain + ')';
        }
    }
}

