package com.aliyun.openservices.eas.predict.http;

import com.aliyun.openservices.eas.discovery.core.DiscoveryClient;
import com.aliyun.openservices.eas.predict.auth.HmacSha1Signature;
import com.aliyun.openservices.eas.predict.request.BladeRequest;
import com.aliyun.openservices.eas.predict.request.CaffeRequest;
import com.aliyun.openservices.eas.predict.request.JsonRequest;
import com.aliyun.openservices.eas.predict.request.TFRequest;
import com.aliyun.openservices.eas.predict.request.TorchRequest;
import com.aliyun.openservices.eas.predict.response.BladeResponse;
import com.aliyun.openservices.eas.predict.response.CaffeResponse;
import com.aliyun.openservices.eas.predict.response.JsonResponse;
import com.aliyun.openservices.eas.predict.response.TFResponse;
import com.aliyun.openservices.eas.predict.response.TorchResponse;
import java.io.IOException;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.TimeZone;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.commons.io.IOUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.http.Header;
import org.apache.http.HttpResponse;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.conn.ConnectTimeoutException;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
import org.apache.http.impl.nio.client.HttpAsyncClients;
import org.apache.http.impl.nio.conn.PoolingNHttpClientConnectionManager;
import org.apache.http.impl.nio.reactor.DefaultConnectingIOReactor;
import org.apache.http.impl.nio.reactor.IOReactorConfig;
import org.apache.http.nio.entity.NByteArrayEntity;
import org.xerial.snappy.Snappy;

/* loaded from: input_file:com/aliyun/openservices/eas/predict/http/PredictClient.class */
/**
 * Client that posts prediction requests to {@code http://<host>/api/predict/<modelName>}
 * through a shared Apache {@link CloseableHttpAsyncClient}.
 *
 * <p>The host is either the fixed endpoint set with {@link #setEndpoint(String)} or, when a
 * direct endpoint is configured with {@link #setDirectEndpoint(String)}, a host obtained from
 * {@link DiscoveryClient} for every request.
 *
 * <p>The status code and error body of the most recent request are stored in plain instance
 * fields ({@link #getErrorCode()}, {@link #getErrorMessage()}), so concurrent calls on the same
 * instance overwrite each other's values. Use {@link #createChlidClient()} to obtain separate
 * instances that share the underlying HTTP client.
 */
public class PredictClient {
    /** Upper bound on discovery lookups made while trying to avoid the previously used URL. */
    private static final int ENDPOINT_RETRY_COUNT = 10;
    /** Charset used for text request bodies, text responses and error bodies. */
    private static final String CHARSET_UTF8 = "UTF-8";
    private static final Log log = LogFactory.getLog(PredictClient.class);

    private CloseableHttpAsyncClient httpclient = null;
    private String token = null;
    private String modelName = null;
    private String endpoint = null;
    private boolean isCompressed = false;
    /** When non-null, receives every response header of each request (see {@link #setTracing}). */
    HashMap<String, String> mapHeader = null;
    private int retryCount = 3;
    private String contentType = "application/octet-stream";
    private int errorCode = 400;
    private String errorMessage;
    private String vipSrvEndPoint = null;
    private String directEndPoint = null;
    /** Per-request timeout in milliseconds; values {@code <= 0} wait without a limit. */
    private int requestTimeout = 0;
    private boolean enableBlacklist = false;
    private int blacklistSize = 10;
    /** Blacklist entry lifetime in seconds (multiplied by 1000 when timestamps are computed). */
    private int blacklistTimeout = 30;
    private int blacklistTimeoutCount = 10;
    /** URL to blacklist record; guarded by {@link #rwlock}. Null until the blacklist is started. */
    private Map<String, BlacklistData> blacklist = null;
    private ReentrantReadWriteLock rwlock = new ReentrantReadWriteLock();

    /**
     * Creates a client with default settings and no HTTP client attached. Used internally for
     * child clients, which receive the parent's HTTP client afterwards.
     */
    public PredictClient() {
    }

    /**
     * Creates a client and starts its own asynchronous HTTP client configured from
     * {@code httpConfig} (connection pool limits, socket options, connect/read timeouts and the
     * request timeout).
     *
     * <p>If the I/O reactor cannot be created the failure is logged and the instance is left
     * without an HTTP client; the constructor itself does not throw.
     *
     * @param httpConfig connection and timeout settings
     */
    public PredictClient(HttpConfig httpConfig) {
        try {
            PoolingNHttpClientConnectionManager connectionManager =
                    new PoolingNHttpClientConnectionManager(new DefaultConnectingIOReactor());
            connectionManager.setMaxTotal(httpConfig.getMaxConnectionCount());
            connectionManager.setDefaultMaxPerRoute(httpConfig.getMaxConnectionPerRoute());
            this.requestTimeout = httpConfig.getRequestTimeout();
            IOReactorConfig reactorConfig = IOReactorConfig.custom()
                    .setTcpNoDelay(true)
                    .setSoTimeout(httpConfig.getReadTimeout())
                    .setSoReuseAddress(true)
                    .setConnectTimeout(httpConfig.getConnectTimeout())
                    .setIoThreadCount(httpConfig.getIoThreadNum())
                    .setSoKeepAlive(httpConfig.isKeepAlive())
                    .build();
            RequestConfig requestConfig = RequestConfig.custom()
                    .setConnectTimeout(httpConfig.getConnectTimeout())
                    .setSocketTimeout(httpConfig.getReadTimeout())
                    .build();
            this.httpclient = HttpAsyncClients.custom()
                    .setConnectionManager(connectionManager)
                    .setDefaultIOReactorConfig(reactorConfig)
                    .setDefaultRequestConfig(requestConfig)
                    .build();
            this.httpclient.start();
        } catch (IOException e) {
            log.error("Failed to initialize the HTTP client", e);
        }
    }

    /** Attaches an already started HTTP client; used to share one client between instances. */
    private PredictClient setHttp(CloseableHttpAsyncClient closeableHttpAsyncClient) {
        this.httpclient = closeableHttpAsyncClient;
        return this;
    }

    /**
     * Sets the token used to sign requests. {@code null} disables the Authorization header;
     * an empty string is ignored and leaves the current token unchanged.
     *
     * @param str the token, or {@code null}
     * @return this client
     */
    public PredictClient setToken(String str) {
        if (str == null || str.length() > 0) {
            this.token = str;
        }
        return this;
    }

    /**
     * Sets how long to wait for a response.
     *
     * @param i timeout in milliseconds; values {@code <= 0} wait without a limit
     * @return this client
     */
    public PredictClient setRequestTimeout(int i) {
        this.requestTimeout = i;
        return this;
    }

    /**
     * Sets the model name appended to {@code /api/predict/} in the request URL and signature.
     *
     * @param str the model name
     * @return this client
     */
    public PredictClient setModelName(String str) {
        this.modelName = str;
        return this;
    }

    /**
     * Sets the fixed host used when no direct endpoint is configured.
     *
     * @param str host (and optional port) placed after {@code http://}
     * @return this client
     */
    public PredictClient setEndpoint(String str) {
        this.endpoint = str;
        return this;
    }

    /**
     * Stores the VIP server endpoint. {@code null} clears it; an empty string is ignored.
     * Within this class the value is only propagated to child clients.
     *
     * @param str the VIP server endpoint, or {@code null}
     * @return this client
     */
    public PredictClient setVIPServer(String str) {
        if (str == null || str.length() > 0) {
            this.vipSrvEndPoint = str;
        }
        return this;
    }

    /**
     * Enables discovery-based host selection and publishes the endpoint through the
     * {@code com.aliyun.eas.discovery} system property.
     *
     * <p>{@code null} clears the direct endpoint on this instance (the system property is left
     * untouched); an empty string is ignored, matching {@link #setToken(String)} and
     * {@link #setVIPServer(String)}.
     *
     * @param str the discovery endpoint, or {@code null}
     * @return this client
     */
    public PredictClient setDirectEndpoint(String str) {
        if (str == null) {
            // System.setProperty rejects null values, so only the field is reset.
            this.directEndPoint = null;
        } else if (str.length() > 0) {
            this.directEndPoint = str;
            System.setProperty("com.aliyun.eas.discovery", str);
        }
        return this;
    }

    /**
     * Turns Snappy handling on or off: when enabled, responses are Snappy-uncompressed and the
     * Content-MD5 header is computed over the Snappy-compressed request bytes.
     *
     * @param z {@code true} to enable
     * @return this client
     */
    public PredictClient setIsCompressed(boolean z) {
        this.isCompressed = z;
        return this;
    }

    /**
     * Sets the number of retries after the first attempt; a request is tried at most
     * {@code i + 1} times.
     *
     * @param i retry count
     * @return this client
     */
    public PredictClient setRetryCount(int i) {
        this.retryCount = i;
        return this;
    }

    /**
     * Supplies a map that will be filled with the response headers of each request. A non-null
     * map also makes requests carry a {@code Client-Timestamp} header.
     *
     * @param hashMap destination for response headers, or {@code null} to disable
     * @return this client
     */
    public PredictClient setTracing(HashMap<String, String> hashMap) {
        this.mapHeader = hashMap;
        return this;
    }

    /**
     * Sets the value of the {@code Content-Type} request header.
     *
     * @param str content type
     * @return this client
     */
    public PredictClient setContentType(String str) {
        this.contentType = str;
        return this;
    }

    /**
     * Enables the URL blacklist and starts a new thread running a {@code BlacklistTask} over it.
     * Each call replaces the blacklist map and starts another thread.
     *
     * @param i  maximum number of blacklist entries
     * @param i2 entry lifetime in seconds
     * @param i3 number of connect timeouts after which a URL is skipped
     * @return this client
     */
    public PredictClient startBlacklistMechanism(int i, int i2, int i3) {
        this.blacklistSize = i;
        this.blacklistTimeout = i2;
        this.blacklistTimeoutCount = i3;
        this.blacklist = new HashMap<String, BlacklistData>();
        new Thread(new BlacklistTask(this.blacklist, this.rwlock, this.blacklistTimeout)).start();
        // Flip the switch last so the flag is never set while the map is still null.
        this.enableBlacklist = true;
        return this;
    }

    /** @return the HTTP status code of the most recent response (400 before any request) */
    public int getErrorCode() {
        return this.errorCode;
    }

    /** @return the body of the most recent non-200 response, or an empty string after a 200 */
    public String getErrorMessage() {
        return this.errorMessage;
    }

    /**
     * Creates a new client that shares this client's HTTP client but uses the given settings.
     * Other settings (retry count, timeout, blacklist, ...) start from their defaults.
     *
     * @param str  token
     * @param str2 endpoint
     * @param str3 model name
     * @return the new client
     */
    public PredictClient createChlidClient(String str, String str2, String str3) {
        PredictClient child = new PredictClient();
        child.setHttp(this.httpclient).setToken(str).setEndpoint(str2).setModelName(str3);
        return child;
    }

    /**
     * Creates a new client that shares this client's HTTP client, token and model name, and
     * copies whichever endpoint kind is configured (VIP server first, then direct endpoint,
     * then plain endpoint). Other settings start from their defaults.
     *
     * @return the new client
     */
    public PredictClient createChlidClient() {
        PredictClient child = new PredictClient();
        child.setHttp(this.httpclient).setToken(this.token).setModelName(this.modelName);
        if (this.vipSrvEndPoint != null) {
            child.setVIPServer(this.vipSrvEndPoint);
        } else if (this.directEndPoint != null) {
            child.setDirectEndpoint(this.directEndPoint);
        } else {
            child.setEndpoint(this.endpoint);
        }
        return child;
    }

    /** Builds the prediction URL for the given host. */
    private String buildUrl(String host) {
        return "http://" + host + "/api/predict/" + this.modelName;
    }

    /**
     * Reports whether the blacklist is enabled and {@code url} has reached the configured
     * timeout count. Takes the read lock for the lookup and always releases it exactly once.
     */
    private boolean isBlacklisted(String url) {
        if (!this.enableBlacklist || this.blacklist == null) {
            return false;
        }
        this.rwlock.readLock().lock();
        try {
            BlacklistData record = this.blacklist.get(url);
            return record != null && record.getCount() >= this.blacklistTimeoutCount;
        } finally {
            this.rwlock.readLock().unlock();
        }
    }

    /**
     * Chooses the URL for the next attempt.
     *
     * <p>Without a direct endpoint the fixed endpoint is always used. With a direct endpoint a
     * host is requested from {@link DiscoveryClient}; when at least two hosts are known, hosts
     * equal to the previous attempt's URL or currently blacklisted are skipped, up to a bounded
     * number of lookups. If every lookup is rejected, the last candidate is returned anyway.
     *
     * @param str URL used by the previous attempt, or an empty string on the first attempt
     * @return the URL to post to
     * @throws Exception if the discovery lookup fails
     */
    private String getUrl(String str) throws Exception {
        if (this.directEndPoint == null) {
            return buildUrl(this.endpoint);
        }
        int maxLookups = ENDPOINT_RETRY_COUNT;
        if (this.enableBlacklist) {
            // Give the lookup enough tries to step past a completely full blacklist.
            maxLookups = Math.max(ENDPOINT_RETRY_COUNT, this.blacklistSize * 2);
        }
        String candidate = "";
        for (int lookup = 0; lookup < maxLookups; lookup++) {
            candidate = buildUrl(DiscoveryClient.srvHost(this.modelName).toInetAddr());
            if (DiscoveryClient.getHosts(this.modelName).size() < 2) {
                // Nothing else to choose from.
                return candidate;
            }
            if (!candidate.equals(str) && !isBlacklisted(candidate)) {
                return candidate;
            }
        }
        return candidate;
    }

    /**
     * Builds the POST request for one attempt: target URL, body, Content-MD5, Date,
     * Content-Type, optional Client-Timestamp and, when a token is set, the {@code EAS}
     * Authorization header.
     *
     * @param bArr request body
     * @param str  URL used by the previous attempt (passed on to {@link #getUrl(String)})
     * @return the request ready to execute
     * @throws Exception if URL selection or signing fails
     */
    private HttpPost generateSignature(byte[] bArr, String str) throws Exception {
        HttpPost httpPost = new HttpPost(getUrl(str));
        // NOTE(review): the entity is created from the bytes *before* compression, so with
        // isCompressed enabled the body sent is uncompressed while Content-MD5 below covers the
        // compressed bytes. Left as is because changing it alters the wire format; confirm the
        // server contract before moving this line below the compression step.
        httpPost.setEntity(new NByteArrayEntity(bArr));
        if (this.isCompressed) {
            try {
                bArr = Snappy.compress(bArr);
            } catch (IOException e) {
                // Best effort: fall back to the uncompressed bytes.
                log.error("Compress Error", e);
            }
        }
        HmacSha1Signature signer = new HmacSha1Signature();
        String md5 = signer.getMD5(bArr);
        httpPost.addHeader("Content-MD5", md5);

        // A new formatter per call: SimpleDateFormat instances must not be shared across threads.
        SimpleDateFormat gmtFormat = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss", Locale.ENGLISH);
        gmtFormat.setTimeZone(TimeZone.getTimeZone("GMT"));
        String dateHeader = gmtFormat.format(new Date()) + " GMT";
        httpPost.addHeader("Date", dateHeader);
        httpPost.addHeader("Content-Type", this.contentType);
        if (this.mapHeader != null) {
            httpPost.addHeader("Client-Timestamp", String.valueOf(System.currentTimeMillis()));
        }
        if (this.token != null) {
            // The string to sign always names application/octet-stream, independent of contentType.
            String stringToSign = "POST\n" + md5 + "\napplication/octet-stream\n" + dateHeader
                    + "\n/api/predict/" + this.modelName;
            httpPost.addHeader("Authorization", "EAS " + signer.computeSignature(this.token, stringToSign));
        }
        return httpPost;
    }

    /**
     * Executes the request, records status and (optionally) response headers, and returns the
     * response body, Snappy-uncompressed when compression is enabled.
     *
     * @param httpPost the request to send
     * @return the response body, or {@code null} if the entity content could not be obtained
     *         because of an {@link IllegalStateException}
     * @throws IOException          on a non-200 status (as {@code HttpException}) or read failure
     * @throws InterruptedException if the waiting thread is interrupted
     * @throws ExecutionException   if the asynchronous exchange failed
     * @throws TimeoutException     if no response arrived within the request timeout; the
     *                              pending exchange is cancelled before rethrowing
     */
    private byte[] getContent(HttpPost httpPost) throws IOException, InterruptedException, ExecutionException, TimeoutException {
        Future<HttpResponse> pending = this.httpclient.execute(httpPost, null);
        HttpResponse httpResponse;
        if (this.requestTimeout > 0) {
            try {
                httpResponse = pending.get(this.requestTimeout, TimeUnit.MILLISECONDS);
            } catch (TimeoutException e) {
                // Do not leave the abandoned exchange running on a pooled connection.
                pending.cancel(true);
                throw e;
            }
        } else {
            httpResponse = pending.get();
        }

        if (this.mapHeader != null) {
            for (Header header : httpResponse.getAllHeaders()) {
                this.mapHeader.put(header.getName(), header.getValue());
            }
        }

        // Once get() has returned normally the future is done, so no isDone()/isCancelled()
        // branching is needed here.
        byte[] body = null;
        try {
            this.errorCode = httpResponse.getStatusLine().getStatusCode();
            this.errorMessage = "";
            if (this.errorCode != 200) {
                if (httpResponse.getEntity() != null) {
                    this.errorMessage = IOUtils.toString(httpResponse.getEntity().getContent(), CHARSET_UTF8);
                }
                throw new HttpException(this.errorCode, this.errorMessage);
            }
            body = IOUtils.toByteArray(httpResponse.getEntity().getContent());
            if (this.isCompressed) {
                body = Snappy.uncompress(body);
            }
        } catch (IllegalStateException e) {
            log.error("Illegal State", e);
        }
        return body;
    }

    /**
     * Sends a Blade request.
     *
     * @param bladeRequest the request
     * @return the response; its content is left unset when no body was received
     * @throws Exception if every attempt fails
     */
    public BladeResponse predict(BladeRequest bladeRequest) throws Exception {
        BladeResponse response = new BladeResponse();
        byte[] content = predict(bladeRequest.getRequest().toByteArray());
        if (content != null) {
            response.setContentValues(content);
        }
        return response;
    }

    /**
     * Sends a TensorFlow request.
     *
     * @param tFRequest the request
     * @return the response; its content is left unset when no body was received
     * @throws Exception if every attempt fails
     */
    public TFResponse predict(TFRequest tFRequest) throws Exception {
        TFResponse response = new TFResponse();
        byte[] content = predict(tFRequest.getRequest().toByteArray());
        if (content != null) {
            response.setContentValues(content);
        }
        return response;
    }

    /**
     * Sends a Caffe request.
     *
     * @param caffeRequest the request
     * @return the response; its content is left unset when no body was received
     * @throws Exception if every attempt fails
     */
    public CaffeResponse predict(CaffeRequest caffeRequest) throws Exception {
        CaffeResponse response = new CaffeResponse();
        byte[] content = predict(caffeRequest.getRequest().toByteArray());
        if (content != null) {
            response.setContentValues(content);
        }
        return response;
    }

    /**
     * Sends a JSON request; the JSON text is encoded as UTF-8.
     *
     * @param jsonRequest the request
     * @return the response; its content is left unset when no body was received
     * @throws Exception if every attempt fails
     */
    public JsonResponse predict(JsonRequest jsonRequest) throws Exception {
        byte[] content = predict(jsonRequest.getJSON().getBytes(CHARSET_UTF8));
        JsonResponse response = new JsonResponse();
        if (content != null) {
            response.setContentValues(content);
        }
        return response;
    }

    /**
     * Sends a Torch request.
     *
     * @param torchRequest the request
     * @return the response; its content is left unset when no body was received
     * @throws Exception if every attempt fails
     */
    public TorchResponse predict(TorchRequest torchRequest) throws Exception {
        TorchResponse response = new TorchResponse();
        byte[] content = predict(torchRequest.getRequest().toByteArray());
        if (content != null) {
            response.setContentValues(content);
        }
        return response;
    }

    /**
     * Sends a text request. The text is encoded as UTF-8 and the response body is decoded as
     * UTF-8, independent of the platform default charset.
     *
     * @param str request text
     * @return the response text, or {@code null} when no body was received
     * @throws Exception if every attempt fails
     */
    public String predict(String str) throws Exception {
        byte[] content = predict(str.getBytes(CHARSET_UTF8));
        if (content != null) {
            return new String(content, CHARSET_UTF8);
        }
        return null;
    }

    /**
     * Records a connect timeout for {@code str}. The caller must hold the write lock.
     *
     * <p>An unknown URL is added with count 1 if the blacklist has room. A known URL has its
     * count incremented until it reaches the configured threshold; after that its timestamp is
     * pushed forward by the blacklist timeout instead.
     */
    private void handleBlacklist(String str) {
        BlacklistData record = this.blacklist.get(str);
        if (record == null) {
            if (this.blacklist.size() < this.blacklistSize) {
                // 1000L keeps the seconds-to-milliseconds conversion in long arithmetic.
                this.blacklist.put(str, new BlacklistData(System.currentTimeMillis() + this.blacklistTimeout * 1000L, 1));
                log.info("Put [" + str + "] into blacklist");
            }
            return;
        }
        int count = record.getCount();
        if (count < this.blacklistTimeoutCount) {
            record.setCount(count + 1);
            log.info("Set [" + str + "] timeoutCount:" + record.getCount());
        } else {
            record.setTimestamp(System.currentTimeMillis() + this.blacklistTimeout * 1000L);
            log.info("Set [" + str + "] timestamp: " + record.getTimestamp() + " timeoutCount: " + record.getCount());
        }
    }

    /**
     * Sends raw bytes and returns the raw response body, retrying failed attempts up to the
     * configured retry count. Each attempt selects its URL again, so with discovery enabled a
     * retry can go to a different host.
     *
     * @param bArr request body
     * @return the response body; {@code null} if the body could not be read or if the retry
     *         count is negative (no attempt is made)
     * @throws Exception if the last attempt fails or the calling thread is interrupted; the
     *                   message contains the URL and the original exception is the cause
     */
    public byte[] predict(byte[] bArr) throws Exception {
        String lastUrl = "";
        for (int attempt = 0; attempt <= this.retryCount; attempt++) {
            try {
                HttpPost request = generateSignature(bArr, lastUrl);
                lastUrl = request.getURI().toString();
                return getContent(request);
            } catch (ConnectTimeoutException e) {
                // NOTE(review): failures raised through Future.get() arrive wrapped in
                // ExecutionException, so this branch may not see connect timeouts of the
                // asynchronous exchange - confirm whether the cause should be unwrapped.
                String message = "URL: " + lastUrl + ", " + e.getMessage();
                if (this.enableBlacklist) {
                    this.rwlock.writeLock().lock();
                    try {
                        handleBlacklist(lastUrl);
                    } finally {
                        this.rwlock.writeLock().unlock();
                    }
                }
                if (attempt == this.retryCount) {
                    log.error(message);
                    throw new Exception(message, e);
                }
                log.debug(message);
            } catch (InterruptedException e) {
                // Restore the flag and stop: retrying would ignore the interruption request.
                Thread.currentThread().interrupt();
                String message = "URL: " + lastUrl + ", " + e.getMessage();
                log.error(message);
                throw new Exception(message, e);
            } catch (Exception e) {
                String message = "URL: " + lastUrl + ", " + e.getMessage();
                if (attempt == this.retryCount) {
                    log.error(message, e);
                    throw new Exception(message, e);
                }
                log.debug(message);
            }
        }
        return null;
    }

    /**
     * Closes the underlying HTTP client. Child clients created from this instance share that
     * client, so they stop working as well. Close failures are logged, not thrown.
     */
    public void shutdown() {
        if (this.httpclient == null) {
            return;
        }
        try {
            this.httpclient.close();
        } catch (IOException e) {
            log.error("Failed to close the HTTP client", e);
        }
    }
}
