package com.alibaba.pairec.linucb;

import java.io.Serializable;
import org.apache.commons.lang3.tuple.Pair;
import org.ejml.simple.SimpleMatrix;

/* loaded from: input_file:com/alibaba/pairec/linucb/HybridArmEJML.class */
/**
 * Per-arm state whose update rules match the arm-specific quantities of hybrid LinUCB
 * (Li et al., 2010), backed by EJML {@link SimpleMatrix}.
 *
 * <p>With non-shared feature length {@code d} and shared feature length {@code k} the arm keeps:
 * <ul>
 *   <li>{@code A} ({@code d x d}): starts as the identity, accumulates {@code x x^T};
 *   <li>{@code B} ({@code d x k}): starts at zero, accumulates {@code x z} (z is {@code 1 x k});
 *   <li>{@code b} ({@code d x 1}): starts at zero, accumulates {@code d * x}.
 * </ul>
 * The transpose of {@code B} and the pseudo-inverse of {@code A} are cached lazily and dropped
 * whenever {@link #learn} changes the state.
 *
 * <p>There is no internal synchronization; NOTE(review): confirm callers confine an arm to one
 * thread or lock externally.
 *
 * <p>No explicit {@code serialVersionUID} is declared, so the JVM-computed one is in effect.
 * Adding one, or changing field / non-private member signatures, would break deserialization of
 * any previously persisted arms.
 */
public class HybridArmEJML implements Serializable {
    private SimpleMatrix matrixA;
    private SimpleMatrix matrixB;
    private SimpleMatrix vectorB;
    // Lazily computed caches; reset to null by learn().
    private SimpleMatrix transposeMatrixB = null;
    private SimpleMatrix invertMatrixA = null;

    /**
     * Creates a fresh arm: A = I, B = 0, b = 0.
     *
     * @param i  non-shared (arm-specific) feature length {@code d}
     * @param i2 shared feature length {@code k}
     */
    public HybridArmEJML(int i, int i2) {
        this.matrixA = SimpleMatrix.identity(i);
        // new SimpleMatrix(rows, cols) is already zero-filled, so no explicit zero() is needed.
        this.matrixB = new SimpleMatrix(i, i2);
        this.vectorB = new SimpleMatrix(i, 1);
    }

    /**
     * Returns the non-shared feature length {@code d} (rows of B), or 0 if B is absent.
     *
     * @return {@code d}, or 0 when B is null
     */
    public int getNonSharedFeatureLength() {
        if (null == this.matrixB) {
            return 0;
        }
        return this.matrixB.numRows();
    }

    /**
     * Returns the shared feature length {@code k} (columns of B), or 0 if B is absent.
     *
     * @return {@code k}, or 0 when B is null
     */
    public int getSharedFeatureLength() {
        if (null == this.matrixB) {
            return 0;
        }
        return this.matrixB.numCols();
    }

    /**
     * Folds one observation into the arm: A += x x^T, B += x z, and (if {@code d != 0}) b += d x.
     *
     * <p>The update is all-or-nothing. If any EJML operation throws (e.g. a dimension mismatch or
     * a null {@code simpleMatrix}), the arm and its caches are left unchanged.
     *
     * @param dArr         non-shared features x, length {@code d}; a null or empty array is a no-op
     * @param simpleMatrix shared features z; must be {@code 1 x k} so that x z matches B
     * @param d            scalar added as {@code d * x} to b; presumably the observed reward
     */
    public void learn(double[] dArr, SimpleMatrix simpleMatrix, double d) {
        if (null == dArr || dArr.length == 0) {
            return;
        }
        SimpleMatrix x = new SimpleMatrix(dArr.length, 1, true, dArr);
        // Compute every new value before committing any of them. Previously matrixA was assigned
        // first, so a failure in the matrixB update left A modified while invertMatrixA still
        // cached the inverse of the old A.
        SimpleMatrix newMatrixA = this.matrixA.plus(x.mult(x.transpose()));
        SimpleMatrix newMatrixB = this.matrixB.plus(x.mult(simpleMatrix));
        SimpleMatrix newVectorB = d != 0.0d ? this.vectorB.plus(x.scale(d)) : this.vectorB;

        this.matrixA = newMatrixA;
        this.matrixB = newMatrixB;
        this.vectorB = newVectorB;
        this.transposeMatrixB = null;
        this.invertMatrixA = null;
    }

    /**
     * Computes this arm's contribution to the shared statistics.
     *
     * @return a pair whose left is {@code B^T A^+ B} ({@code k x k}) and whose right is
     *     {@code B^T A^+ b} ({@code k x 1}), where {@code A^+} is the pseudo-inverse of A
     */
    public Pair<SimpleMatrix, SimpleMatrix> computeUpdate() {
        if (null == this.transposeMatrixB) {
            this.transposeMatrixB = this.matrixB.transpose();
        }
        SimpleMatrix btAinv = this.transposeMatrixB.mult(cachedInverseA());
        return Pair.of(btAinv.mult(this.matrixB), btAinv.mult(this.vectorB));
    }

    /**
     * Returns the pseudo-inverse of A as its row-major backing array.
     *
     * <p>The array is the live storage of the cached inverse, not a copy. Mutating it corrupts
     * the cache until the next {@link #learn} call.
     *
     * @return row-major data of {@code A^+} ({@code d x d})
     */
    public double[] getInvertMatrixA() {
        return cachedInverseA().getDDRM().data;
    }

    /**
     * Returns B as its row-major backing array (live storage, not a copy).
     *
     * @return row-major data of B ({@code d x k})
     */
    public double[] getMatrixB() {
        return this.matrixB.getDDRM().data;
    }

    /**
     * Returns b as its backing array (live storage, not a copy).
     *
     * @return data of b ({@code d x 1})
     */
    public double[] getVectorB() {
        return this.vectorB.getDDRM().data;
    }

    /** Returns A's pseudo-inverse, computing and caching it on first use after a learn(). */
    private SimpleMatrix cachedInverseA() {
        if (null == this.invertMatrixA) {
            // A = I + sum(x x^T) is symmetric positive definite, so the pseudo-inverse equals
            // the true inverse; pseudoInverse() is kept to preserve existing numerical results.
            this.invertMatrixA = this.matrixA.pseudoInverse();
        }
        return this.invertMatrixA;
    }
}
