package net.sourceforge.evoj.neural;

import java.util.EnumMap;
import java.util.Map;
import net.sourceforge.evoj.PropertyHandler;
import net.sourceforge.evoj.util.Matrix;

/* loaded from: input_file:net/sourceforge/evoj/neural/MultiplicativeRecurrentLayer.class */
/**
 * Recurrent neuron layer whose recurrent contribution passes through multiplicative factors.
 *
 * <p>For an input vector {@code x} and the output {@code h} produced by the previous call,
 * {@link #getOutputs(float[])} computes
 * <pre>
 *   f   = (fx * x) .* (fh * h)          // element-wise product, one entry per factor
 *   out = function(hx * x + hf * f)
 * </pre>
 * and then remembers {@code out} as {@code h} for the next call.
 *
 * <p>The four weight matrices and the initial-state vector are each bound to DNA property
 * handlers. They are synchronised with those handlers by {@link #readDna()} and
 * {@link #writeDna()}.
 *
 * <p>Intermediate results are kept in per-instance buffers that are reused across calls, so an
 * instance should not be shared between concurrently running threads.
 */
public class MultiplicativeRecurrentLayer extends AbsctractLayer {
    /** Initial recurrent state, one value per output neuron. */
    private final float[] initialState;
    /** Input-to-factor weights, shape [factorCount][inputCount]. */
    private final float[][] fx;
    private final PropertyHandler<Float>[] fxHandlers;
    /** Previous-output-to-factor weights, shape [factorCount][outputCount]. */
    private final float[][] fh;
    private final PropertyHandler<Float>[] fhHandlers;
    /** Factor-to-output weights, shape [outputCount][factorCount]. */
    private final float[][] hf;
    private final PropertyHandler<Float>[] hfHandlers;
    /** Direct input-to-output weights, shape [outputCount][inputCount]. */
    private final float[][] hx;
    private final PropertyHandler<Float>[] hxHandlers;
    private final PropertyHandler<Float>[] initialStateHandlers;
    private final int factorCount;
    private final int inputCount;
    private final int outputCount;
    // Scratch buffers reused by getOutputs. tmpFH and tmpOut are re-assigned from
    // Matrix.mult's return value, so they are not final.
    private final float[] tmpFX;
    private float[] tmpFH;
    private float[] tmpOut;
    /** Output of the previous getOutputs call; feeds the fh (recurrent) path. */
    private final float[] lastOutput;
    private final float[] totalInputs;
    private final float[] outputs;

    /**
     * Creates the layer, allocates its weight matrices, and binds each of them (and the initial
     * state) to DNA property handlers.
     *
     * @param layerDescriptor supplies the layer's input and output counts
     * @param handlerHelper   used to create the property handlers backing the weights
     */
    // NOTE(review): the decompiler reports this constructor as package-private in the original bytecode.
    public MultiplicativeRecurrentLayer(LayerDescriptor layerDescriptor, HandlerHelper handlerHelper) {
        super(layerDescriptor, handlerHelper);
        this.factorCount = this.layerModel.getStateModel().getFactorCount();
        this.outputCount = layerDescriptor.getOutputCount();
        this.inputCount = layerDescriptor.getInputCount();

        this.fx = new float[this.factorCount][this.inputCount];
        this.fxHandlers = initHandlers(handlerHelper, this.fx, "fxFactors");
        this.fh = new float[this.factorCount][this.outputCount];
        this.fhHandlers = initHandlers(handlerHelper, this.fh, "fhFactors");
        this.hf = new float[this.outputCount][this.factorCount];
        this.hfHandlers = initHandlers(handlerHelper, this.hf, "hfFactors");
        this.hx = new float[this.outputCount][this.inputCount];
        // The "weights" handlers belong to hx. They must not overwrite hfHandlers: hf has a
        // different shape, and hxHandlers would otherwise stay null.
        this.hxHandlers = initHandlers(handlerHelper, this.hx, "weights");

        // Must be allocated before readDna/writeDna/getInternalVars touch it.
        this.initialState = new float[this.outputCount];
        this.initialStateHandlers = doInit(handlerHelper, this.outputCount, "initialState");

        this.tmpFX = new float[this.fx.length];
        this.tmpFH = new float[this.fh.length];
        this.tmpOut = new float[this.outputCount];
        this.lastOutput = new float[this.outputCount];
        this.totalInputs = new float[this.outputCount];
        this.outputs = new float[this.outputCount];
    }

    /**
     * Runs one recurrent step and remembers the result for the next step.
     *
     * <p>Assumes {@code Matrix.mult(m, v, out)} computes the matrix-vector product {@code m * v}
     * into {@code out} and returns the result array. TODO confirm against {@code Matrix}.
     *
     * @param inputs input vector, expected to hold {@code inputCount} values
     * @return this layer's internal output buffer; its contents are overwritten by the next call
     */
    @Override // net.sourceforge.evoj.neural.NeuronLayer
    public float[] getOutputs(float[] inputs) {
        // Factor activations from the current input, gated element-wise by the factor
        // activations from the previous output.
        float[] factors = Matrix.mult(this.fx, inputs, this.tmpFX);
        this.tmpFH = Matrix.mult(this.fh, this.lastOutput, this.tmpFH);
        for (int i = 0; i < this.tmpFH.length; i++) {
            factors[i] *= this.tmpFH[i];
        }

        // Project the gated factors to the outputs and add the direct input contribution.
        this.tmpOut = Matrix.mult(this.hf, factors, this.tmpOut);
        Matrix.mult(this.hx, inputs, this.totalInputs);
        add(this.totalInputs, this.tmpOut);

        this.function.calc(this.totalInputs, this.outputs);
        // Copy rather than alias, so callers that mutate the returned array cannot corrupt the
        // recurrent state.
        System.arraycopy(this.outputs, 0, this.lastOutput, 0, this.outputs.length);
        return this.outputs;
    }

    /**
     * Adds {@code addend} element-wise into {@code target}, in place.
     *
     * @param target array that receives the sums; its length bounds the loop
     * @param addend values to add; must be at least as long as {@code target}
     * @return {@code target}
     */
    private float[] add(float[] target, float[] addend) {
        for (int i = 0; i < target.length; i++) {
            target[i] += addend[i];
        }
        return target;
    }

    /** Loads all weight matrices and the initial state from their DNA handlers. */
    @Override // net.sourceforge.evoj.neural.NeuronLayer
    public void readDna() {
        readMatrix(this.fx, this.fxHandlers);
        readMatrix(this.fh, this.fhHandlers);
        readMatrix(this.hf, this.hfHandlers);
        readMatrix(this.hx, this.hxHandlers);
        readArray(this.initialState, this.initialStateHandlers);
        // NOTE(review): initialState is not copied into lastOutput here or anywhere else in this
        // class. Confirm whether the recurrent state is meant to be seeded from it.
    }

    /** Stores all weight matrices and the initial state into their DNA handlers. */
    @Override // net.sourceforge.evoj.neural.NeuronLayer
    public void writeDna() {
        writeMatrix(this.fx, this.fxHandlers);
        writeMatrix(this.fh, this.fhHandlers);
        writeMatrix(this.hf, this.hfHandlers);
        writeMatrix(this.hx, this.hxHandlers);
        writeArray(this.initialState, this.initialStateHandlers);
    }

    /** @return the live initial-state array (not a copy) */
    public float[] getInitialState() {
        return this.initialState;
    }

    /** @return the live input-to-factor matrix (not a copy) */
    public float[][] getFxMatrix() {
        return this.fx;
    }

    /** @return the live previous-output-to-factor matrix (not a copy) */
    public float[][] getFhMatrix() {
        return this.fh;
    }

    /** @return the live factor-to-output matrix (not a copy) */
    public float[][] getHfMatrix() {
        return this.hf;
    }

    /** @return the live input-to-output matrix (not a copy) */
    public float[][] getHxMatrix() {
        return this.hx;
    }

    /**
     * Snapshots the layer's weights and initial state.
     *
     * @return a new map holding copies (made via {@code cloneMatrix} / {@code clone()}) of every
     *         weight matrix and of the initial state
     */
    @Override // net.sourceforge.evoj.neural.NeuronLayer
    public Map<InternalVariableType, Object> getInternalVars() {
        Map<InternalVariableType, Object> vars =
                new EnumMap<InternalVariableType, Object>(InternalVariableType.class);
        vars.put(InternalVariableType.FX_MATRIX, cloneMatrix(this.fx));
        vars.put(InternalVariableType.FH_MATRIX, cloneMatrix(this.fh));
        vars.put(InternalVariableType.HF_MATRIX, cloneMatrix(this.hf));
        vars.put(InternalVariableType.HX_MATRIX, cloneMatrix(this.hx));
        vars.put(InternalVariableType.INITIAL_STATE, this.initialState.clone());
        return vars;
    }

    /**
     * Copies weights and initial state from {@code map} into this layer's own arrays.
     *
     * @param map values keyed by variable type; matrices are expected as {@code float[][]} and the
     *            initial state as {@code float[]}
     */
    @Override // net.sourceforge.evoj.neural.NeuronLayer
    public void setInternalVars(Map<InternalVariableType, Object> map) {
        copyMatrix((float[][]) map.get(InternalVariableType.FX_MATRIX), this.fx, InternalVariableType.FX_MATRIX);
        copyMatrix((float[][]) map.get(InternalVariableType.FH_MATRIX), this.fh, InternalVariableType.FH_MATRIX);
        copyMatrix((float[][]) map.get(InternalVariableType.HF_MATRIX), this.hf, InternalVariableType.HF_MATRIX);
        copyMatrix((float[][]) map.get(InternalVariableType.HX_MATRIX), this.hx, InternalVariableType.HX_MATRIX);
        copyArray((float[]) map.get(InternalVariableType.INITIAL_STATE), this.initialState, InternalVariableType.INITIAL_STATE);
    }
}
