package com.alibaba.pairec.linucb;

import com.alibaba.pairec.io.Compress;
import com.alibaba.pairec.io.Hologres;
import java.util.Iterator;
import java.util.Map;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.table.data.GenericArrayData;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.util.Collector;
import org.ejml.simple.SimpleMatrix;

/**
 * Flink keyed function that incrementally trains a hybrid (shared + per-arm) LinUCB-style model.
 *
 * <p>The stream is keyed by arm id ({@code Long}). Each key keeps its own {@link HybridArmEJML}
 * in Flink value state (1 hour TTL, refreshed on create and write). Events buffered in
 * {@code events} (inherited from {@link AbstractLearner}) are processed when a timer fires.
 * Events that are old enough are validated and learned. The arm is then written back to state
 * and emitted as a {@link RowData} carrying its compressed parameters. The shared (global)
 * parameters are read from and written back to Hologres under the id configured by
 * {@code global.model.id}.
 */
public class HybridLearner extends AbstractLearner {

    /** Feature vectors longer than this are treated as malformed and their event is dropped. */
    private static final int MAX_FEATURE_LENGTH = 10000;

    /** Per-arm model for the current key. */
    private transient ValueState<HybridArmEJML> armState;

    /** Storage key of the shared parameters; overridden by job parameter {@code global.model.id}. */
    private long globalModelId = -1;

    /** Storage for the shared parameters; runtime-only, created in {@link #open(Configuration)}. */
    private transient Hologres storage;

    /**
     * Reads job parameters, registers the TTL'd per-arm state and obtains the Hologres handle.
     *
     * @param configuration Flink configuration passed through to {@link AbstractLearner}
     */
    @Override // com.alibaba.pairec.linucb.AbstractLearner
    public void open(Configuration configuration) {
        super.open(configuration);
        RuntimeContext runtimeContext = getRuntimeContext();
        ParameterTool params = ParameterTool.fromMap(
                runtimeContext.getExecutionConfig().getGlobalJobParameters().toMap());
        this.globalModelId = params.getLong("global.model.id", this.globalModelId);

        StateTtlConfig ttlConfig = StateTtlConfig.newBuilder(Time.hours(1L))
                .setUpdateType(StateTtlConfig.UpdateType.OnCreateAndWrite)
                .setStateVisibility(StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp)
                .cleanupIncrementally(1000, true)
                .cleanupInRocksdbCompactFilter(1000L)
                .build();
        ValueStateDescriptor<HybridArmEJML> descriptor =
                new ValueStateDescriptor<>("armBuilder", HybridArmEJML.class);
        descriptor.enableTimeToLive(ttlConfig);
        this.armState = runtimeContext.getState(descriptor);
        this.storage = Hologres.getInstance(params);
    }

    /**
     * Learns every buffered event that is due, updates the shared parameters and emits the arm.
     *
     * <p>Each buffered event (all must belong to the current key) is handled as follows:
     * <ul>
     *   <li>Nothing happens until it is at least {@code min(durationMsec["check_feature"], delay)}
     *       old. Here {@code delay} is {@code durationMsec[eventType]}, defaulting to
     *       {@code defaultDuration}.</li>
     *   <li>It is then removed and counted as filtered if either feature vector is missing,
     *       empty or longer than {@value #MAX_FEATURE_LENGTH}.</li>
     *   <li>Once it is {@code delay} old, it is learned into the arm and the shared parameters
     *       and removed from the buffer.</li>
     * </ul>
     * If at least one event was learned, the shared parameters are written back to storage.
     * Whenever an arm exists (loaded from state or newly created), it is written to state and
     * emitted, even if nothing was learned in this call.
     *
     * <p>NOTE(review): the shared parameters are loaded, modified locally and stored back as
     * absolute values. If several subtasks or keys do this concurrently, one write may overwrite
     * another's update unless {@code Hologres.putModelArgs} serializes them. Confirm.
     *
     * @param timestamp firing time of the timer, compared against each event's event time
     * @param onTimerContext timer context; its current key is the arm id
     * @param collector receives at most one row per call, for the current arm
     * @throws Exception propagated from Flink state access
     */
    @Override
    public void onTimer(long timestamp, KeyedProcessFunction<Long, Event, RowData>.OnTimerContext onTimerContext, Collector<RowData> collector) throws Exception {
        final long armId = onTimerContext.getCurrentKey();
        final int subtask = getRuntimeContext().getIndexOfThisSubtask();
        HybridArmEJML arm = this.armState.value();
        // Shared parameters (A0, b0); loaded lazily on the first learnable event of this call.
        SimpleMatrix globalA = null;
        SimpleMatrix globalB = null;
        int learned = 0;

        Iterator<?> it = this.events.iterator();
        while (it.hasNext()) {
            Event event = (Event) ((Map.Entry<?, ?>) it.next()).getValue();
            assert armId == event.getArmId()
                    : "event for arm " + event.getArmId() + " buffered under key " + armId;

            long learnDelay = this.durationMsec.getOrDefault(event.getEventType(), Long.valueOf(this.defaultDuration)).longValue();
            long checkDelay = Long.min(this.durationMsec.getOrDefault("check_feature", Long.valueOf(learnDelay)).longValue(), learnDelay);
            long age = timestamp - event.getEventTime();
            if (age < checkDelay) {
                continue; // too young even for feature validation
            }

            Pair<double[], double[]> features = event.getFeatures();
            // A missing pair is handled like missing features instead of failing the job.
            double[] nonShared = features == null ? null : features.getLeft();
            if (!isUsableFeature(nonShared)) {
                it.remove();
                recordFiltered(subtask, "empty feature:", event);
                continue;
            }
            double[] shared = features.getRight();
            if (!isUsableFeature(shared)) {
                it.remove();
                recordFiltered(subtask, "empty share feature:", event);
                continue;
            }
            if (age < learnDelay) {
                continue; // features are valid, but the event is not yet due for learning
            }

            if (learned == 0) {
                logger.info(String.format("[%d] learning arm %d", subtask, armId));
            }
            arm = ensureArmShape(arm, nonShared.length, shared.length);
            if (globalA == null || globalB == null) {
                Pair<SimpleMatrix, SimpleMatrix> loaded = loadGlobalParams(subtask, shared.length);
                globalA = loaded.getLeft();
                globalB = loaded.getRight();
            }
            int globalDim = globalB.numRows();
            if (globalDim != shared.length) {
                logger.warn("[" + subtask + "] global param dim change from " + globalDim + " to " + shared.length);
                Pair<SimpleMatrix, SimpleMatrix> fresh = freshGlobalParams(shared.length);
                globalA = fresh.getLeft();
                globalB = fresh.getRight();
            }

            double reward = event.getReward();
            if (learned == 0) {
                // The arm's computeUpdate() is added once before the batch. The same quantity,
                // recomputed from the updated arm, is subtracted in publishGlobalParams. This
                // presumably swaps the arm's old contribution for its new one (hybrid LinUCB
                // bookkeeping). Confirm against HybridArmEJML.
                Pair<SimpleMatrix, SimpleMatrix> before = arm.computeUpdate();
                globalA = globalA.plus(before.getLeft());
                globalB = globalB.plus(before.getRight());
            }
            SimpleMatrix z = new SimpleMatrix(shared.length, 1, true, shared);
            SimpleMatrix zT = z.transpose();
            arm.learn(nonShared, zT, reward);
            globalA = globalA.plus(z.mult(zT));
            if (reward != 0.0d) {
                globalB = globalB.plus(z.scale(reward));
            }

            it.remove();
            learned++;
            this.learnCounter.inc();
            long learnedTotal = this.learnCounter.getCount();
            if (learnedTotal % 10 == 0) {
                logger.info(String.format("[%d] learned events:%d%n", subtask, learnedTotal));
            }
        }

        // learned > 0 implies arm, globalA and globalB were all assigned in the loop.
        if (learned > 0) {
            publishGlobalParams(subtask, arm, globalA, globalB);
        }
        if (arm != null) {
            this.armState.update(arm);
            collector.collect(toRowData(armId, arm));
            this.emitCounter.inc();
            long emitted = this.emitCounter.getCount();
            if (emitted % 100 == 0) {
                logger.info(String.format("[%d] emit %d arms:%n", subtask, emitted));
            }
        }
    }

    /** Returns true when a feature vector is present, non-empty and not over-long. */
    private static boolean isUsableFeature(double[] feature) {
        return feature != null && feature.length > 0 && feature.length <= MAX_FEATURE_LENGTH;
    }

    /** Counts a filtered event and logs a sample of it every 100 filtered events. */
    private void recordFiltered(int subtask, String label, Event event) {
        this.filterCounter.inc();
        long count = this.filterCounter.getCount();
        if (count % 100 == 0) {
            logger.warn(String.format("[%d] filter empty feature events:%d%n", subtask, count));
            logger.warn(label + event);
        }
    }

    /**
     * Returns {@code arm} if it matches the given feature lengths. Otherwise returns a fresh arm
     * with those lengths.
     *
     * <p>NOTE(review): when an existing arm is replaced, its previous contribution to the shared
     * parameters is not backed out here. Confirm whether that is acceptable.
     */
    private HybridArmEJML ensureArmShape(HybridArmEJML arm, int nonSharedLength, int sharedLength) {
        if (arm == null) {
            return new HybridArmEJML(nonSharedLength, sharedLength);
        }
        int oldNonShared = arm.getNonSharedFeatureLength();
        int oldShared = arm.getSharedFeatureLength();
        if (oldNonShared == nonSharedLength && oldShared == sharedLength) {
            return arm;
        }
        HybridArmEJML fresh = new HybridArmEJML(nonSharedLength, sharedLength);
        logger.warn(String.format("Feature length change from %d to %d, Shared feature length change from %d to %d%n", oldNonShared, nonSharedLength, oldShared, sharedLength));
        return fresh;
    }

    /**
     * Loads (A0, b0) from storage. Falls back to (identity, zero) when storage has none.
     *
     * @return pair of (A0, b0) with b0 as a column vector
     */
    private Pair<SimpleMatrix, SimpleMatrix> loadGlobalParams(int subtask, int sharedLength) {
        Pair<double[][], double[]> modelArgs = this.storage.getModelArgs(this.globalModelId, sharedLength);
        if (modelArgs == null) {
            Pair<SimpleMatrix, SimpleMatrix> fresh = freshGlobalParams(sharedLength);
            logger.warn("[" + subtask + "] new global param");
            return fresh;
        }
        double[] b0 = modelArgs.getRight();
        SimpleMatrix globalB = new SimpleMatrix(b0.length, 1, true, b0);
        SimpleMatrix globalA = new SimpleMatrix(modelArgs.getLeft());
        return Pair.of(globalA, globalB);
    }

    /** Returns initial shared parameters: an identity A0 and a zero column vector b0. */
    private static Pair<SimpleMatrix, SimpleMatrix> freshGlobalParams(int sharedLength) {
        SimpleMatrix b0 = new SimpleMatrix(sharedLength, 1);
        b0.zero();
        return Pair.of(SimpleMatrix.identity(sharedLength), b0);
    }

    /**
     * Subtracts the updated arm's {@code computeUpdate()} from the accumulated shared parameters
     * and stores (pseudo-inverse of A0, A0, b0). Logs an error if the store reports failure.
     */
    private void publishGlobalParams(int subtask, HybridArmEJML arm, SimpleMatrix globalA, SimpleMatrix globalB) {
        Pair<SimpleMatrix, SimpleMatrix> after = arm.computeUpdate();
        SimpleMatrix a0 = globalA.minus(after.getLeft());
        SimpleMatrix b0 = globalB.minus(after.getRight());
        if (!this.storage.putModelArgs(this.globalModelId, a0.pseudoInverse().getDDRM().data, a0.getDDRM().data, b0.getDDRM().data)) {
            logger.error("[" + subtask + "] put global model failed.");
        }
    }

    /**
     * Builds the output row: (arm id, emit wall-clock millis, compressed inverse A,
     * compressed vector b, compressed matrix B).
     */
    private static RowData toRowData(long armId, HybridArmEJML arm) {
        int[] invertMatrixA = Compress.convertToIntArray(Compress.compressArray(arm.getInvertMatrixA()));
        int[] vectorB = Compress.convertToIntArray(Compress.compressArray(arm.getVectorB()));
        int[] matrixB = Compress.convertToIntArray(Compress.compressArray(arm.getMatrixB()));
        GenericRowData row = new GenericRowData(5);
        row.setField(0, armId);
        row.setField(1, System.currentTimeMillis());
        row.setField(2, new GenericArrayData(invertMatrixA));
        row.setField(3, new GenericArrayData(vectorB));
        row.setField(4, new GenericArrayData(matrixB));
        return row;
    }
}
