/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kudu.client;

import java.io.IOException;
import java.lang.reflect.Field;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.PrivilegedActionException;
import java.security.cert.Certificate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.Callable;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.kerberos.KerberosTicket;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import org.apache.kudu.client.Bytes;
import org.apache.kudu.client.CallResponse;
import org.apache.kudu.client.KuduException;
import org.apache.kudu.client.KuduRpc;
import org.apache.kudu.client.NonRecoverableException;
import org.apache.kudu.client.RecoverableException;
import org.apache.kudu.client.RpcOutboundMessage;
import org.apache.kudu.client.SecurityContext;
import org.apache.kudu.client.Status;
import org.apache.kudu.client.internals.SecurityManagerCompatibility;
import org.apache.kudu.rpc.RpcHeader;
import org.apache.kudu.security.Token;
import org.apache.kudu.shaded.com.google.common.base.Joiner;
import org.apache.kudu.shaded.com.google.common.base.Preconditions;
import org.apache.kudu.shaded.com.google.common.collect.ImmutableSet;
import org.apache.kudu.shaded.com.google.common.collect.Lists;
import org.apache.kudu.shaded.com.google.common.collect.Maps;
import org.apache.kudu.shaded.com.google.common.collect.Sets;
import org.apache.kudu.shaded.com.google.protobuf.ByteString;
import org.apache.kudu.shaded.com.google.protobuf.UnsafeByteOperations;
import org.apache.kudu.shaded.io.netty.buffer.ByteBuf;
import org.apache.kudu.shaded.io.netty.buffer.Unpooled;
import org.apache.kudu.shaded.io.netty.channel.Channel;
import org.apache.kudu.shaded.io.netty.channel.ChannelHandler;
import org.apache.kudu.shaded.io.netty.channel.ChannelHandlerAdapter;
import org.apache.kudu.shaded.io.netty.channel.ChannelHandlerContext;
import org.apache.kudu.shaded.io.netty.channel.SimpleChannelInboundHandler;
import org.apache.kudu.shaded.io.netty.channel.embedded.EmbeddedChannel;
import org.apache.kudu.shaded.io.netty.handler.ssl.SslHandler;
import org.apache.kudu.shaded.io.netty.util.concurrent.Future;
import org.apache.kudu.util.SecurityUtil;
import org.apache.yetus.audience.InterfaceAudience;
import org.ietf.jgss.GSSException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@InterfaceAudience.Private
public class Negotiator
extends SimpleChannelInboundHandler<CallResponse> {
    private static final Logger LOG = LoggerFactory.getLogger(Negotiator.class);
    // Callback handler passed to Sasl.createSaslClient(); answers name/password callbacks.
    private final SaslClientCallbackHandler saslCallback = new SaslClientCallbackHandler();
    // RPC feature flags the client always advertises in its NEGOTIATE request.
    private static final ImmutableSet<RpcHeader.RpcFeatureFlag> SUPPORTED_RPC_FEATURES = ImmutableSet.of(RpcHeader.RpcFeatureFlag.APPLICATION_FEATURE_FLAGS, RpcHeader.RpcFeatureFlag.TLS);
    // Call ID of the ConnectionContextPB sent once negotiation has finished.
    static final int CONNECTION_CTX_CALL_ID = -3;
    // Call ID used for all negotiation (NegotiatePB) messages; responses with another ID are rejected.
    static final int SASL_CALL_ID = -33;
    // Cipher suites to enable on the TLS engine, in preference order; only those the engine supports are enabled.
    static final String[] PREFERRED_CIPHER_SUITES = new String[]{"TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384", "TLS_CHACHA20_POLY1305_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_128_CCM", "TLS_ECDHE_ECDSA_WITH_AES_256_CCM", "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384", "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384"};
    // TLS protocol versions to enable, filtered against the engine's supported protocols the same way.
    static final String[] PREFERRED_PROTOCOLS = new String[]{"TLSv1.3", "TLSv1.2"};
    // Server host name; passed to the SASL client as the server name and used in log messages.
    private final String remoteHostname;
    // Source of the subject, tokens, real user name and TLS engines.
    private final SecurityContext securityContext;
    // Authentication token to offer, or null when none is usable (authnTokenNotUsedReason says why).
    private final Token.SignedTokenPB authnToken;
    // JWT to offer, or null unless one is available and the client trusts some certificates.
    private final Token.JwtRawPB jsonWebToken;
    // Why the authentication token is not being used; appended to SASL mechanism failure messages.
    private AuthnTokenNotUsedReason authnTokenNotUsedReason = null;
    // Current step of the client-side negotiation state machine.
    private State state = State.INITIAL;
    // SASL client for the chosen mechanism; only set when SASL authentication was chosen.
    private SaslClient saslClient;
    // SASL mechanism picked by chooseAndInitializeSaslMech(); null until then.
    private SaslMechanism chosenMech;
    // Authentication type chosen from the server's NEGOTIATE response.
    private RpcHeader.AuthenticationTypePB.TypeCase chosenAuthnType;
    // Feature flags advertised by the server (UNKNOWN entries removed).
    private Set<RpcHeader.RpcFeatureFlag> serverFeatures;
    // Embedded channel running the TLS handshake; its records are tunneled inside negotiation messages.
    private EmbeddedChannel sslEmbedder;
    // Nonce from the server's GSSAPI SASL_SUCCESS message; wrapped and returned in the connection context.
    private byte[] nonce;
    // Completes when the TLS handshake inside sslEmbedder finishes.
    private Future<Channel> sslHandshakeFuture;
    // Server's leaf certificate from the TLS handshake; checked against GSSAPI channel bindings.
    private Certificate peerCert;
    // SASL protocol (service) name passed to Sasl.createSaslClient().
    private final String saslProtocolName;
    // If true, negotiation fails when only SASL PLAIN could be chosen.
    private final boolean requireAuthentication;
    // If true, negotiation fails when the server does not advertise TLS.
    private final boolean requireEncryption;
    // If false, loopback connections may use TLS only for authentication (handler not kept in the pipeline).
    private final boolean encryptLoopback;
    // Test hook: when true, isLoopbackConnection() always returns true.
    @InterfaceAudience.LimitedPrivate(value={"Test"})
    boolean overrideLoopbackForTests;

    /**
     * Creates a negotiator for a single connection.
     *
     * @param remoteHostname server host name, passed to the SASL client as the server name
     * @param securityContext source of credentials, tokens, the real user name and TLS engines
     * @param ignoreAuthnToken if true, never offer the authentication token even when one exists
     * @param saslProtocolName SASL protocol (service) name
     * @param requireAuthentication if true, fail instead of settling for SASL PLAIN
     * @param requireEncryption if true, fail when the server does not advertise TLS
     * @param encryptLoopback if false, advertise TLS_AUTHENTICATION_ONLY on loopback connections
     */
    public Negotiator(String remoteHostname, SecurityContext securityContext, boolean ignoreAuthnToken, String saslProtocolName, boolean requireAuthentication, boolean requireEncryption, boolean encryptLoopback) {
        this.remoteHostname = remoteHostname;
        this.securityContext = securityContext;
        this.saslProtocolName = saslProtocolName;
        this.requireAuthentication = requireAuthentication;
        this.requireEncryption = requireEncryption;
        this.encryptLoopback = encryptLoopback;

        // Decide whether the authentication token may be offered; if not, remember why so that
        // SASL failures can explain it.
        Token.SignedTokenPB availableToken = securityContext.getAuthenticationToken();
        AuthnTokenNotUsedReason whyNotUsed = null;
        if (availableToken == null) {
            whyNotUsed = AuthnTokenNotUsedReason.NONE_AVAILABLE;
        } else if (ignoreAuthnToken) {
            whyNotUsed = AuthnTokenNotUsedReason.FORBIDDEN_BY_POLICY;
        } else if (!securityContext.hasTrustedCerts()) {
            // Token authentication uses the non-trust-all TLS engine, which needs trusted certs.
            whyNotUsed = AuthnTokenNotUsedReason.NO_TRUSTED_CERTS;
        }
        this.authnToken = whyNotUsed == null ? availableToken : null;
        this.authnTokenNotUsedReason = whyNotUsed;

        // A JWT is likewise only offered when the client trusts some certificates.
        Token.JwtRawPB availableJwt = securityContext.getJsonWebToken();
        boolean jwtUsable = availableJwt != null && securityContext.hasTrustedCerts();
        this.jsonWebToken = jwtUsable ? availableJwt : null;
    }

    /**
     * Starts connection negotiation by sending the client's NEGOTIATE request.
     *
     * @param ctx context of the channel being negotiated
     */
    public void sendHello(ChannelHandlerContext ctx) {
        sendNegotiateMessage(ctx);
    }

    /**
     * Sends the NEGOTIATE request advertising the client's feature flags and authentication
     * types, then waits for the server's NEGOTIATE response.
     */
    private void sendNegotiateMessage(ChannelHandlerContext ctx) {
        RpcHeader.NegotiatePB.Builder negotiate = RpcHeader.NegotiatePB.newBuilder();
        negotiate.setStep(RpcHeader.NegotiatePB.NegotiateStep.NEGOTIATE);
        for (RpcHeader.RpcFeatureFlag supported : SUPPORTED_RPC_FEATURES) {
            negotiate.addSupportedFeatures(supported);
        }
        // On loopback connections the client may skip encryption unless told otherwise.
        boolean offerAuthOnly = isLoopbackConnection(ctx.channel()) && !encryptLoopback;
        if (offerAuthOnly) {
            negotiate.addSupportedFeatures(RpcHeader.RpcFeatureFlag.TLS_AUTHENTICATION_ONLY);
        }
        // SASL is always offered; token and JWT authentication only when one is usable.
        negotiate.addAuthnTypesBuilder().setSasl(RpcHeader.AuthenticationTypePB.Sasl.getDefaultInstance());
        if (authnToken != null) {
            negotiate.addAuthnTypesBuilder().setToken(RpcHeader.AuthenticationTypePB.Token.getDefaultInstance());
        }
        if (jsonWebToken != null) {
            negotiate.addAuthnTypesBuilder().setJwt(RpcHeader.AuthenticationTypePB.Jwt.getDefaultInstance());
        }
        state = State.AWAIT_NEGOTIATE;
        sendSaslMessage(ctx, negotiate.build());
    }

    /**
     * Writes a negotiation message to the channel under the SASL call ID.
     *
     * @param ctx channel context to write to
     * @param msg negotiation message to send
     */
    private void sendSaslMessage(ChannelHandlerContext ctx, RpcHeader.NegotiatePB msg) {
        RpcHeader.RequestHeader.Builder header = RpcHeader.RequestHeader.newBuilder().setCallId(SASL_CALL_ID);
        ctx.writeAndFlush(new RpcOutboundMessage(header, msg), ctx.voidPromise());
    }

    /**
     * Dispatches a negotiation response from the server according to the current state.
     * An error response ends negotiation and fires a {@link Failure} to the next handler.
     *
     * @throws IllegalStateException if a message arrives in a state that expects none
     * @throws IOException from the state-specific handlers
     */
    @Override
    public void channelRead0(ChannelHandlerContext ctx, CallResponse msg) throws IOException {
        if (msg.getHeader().getIsError()) {
            handleErrorResponse(ctx, msg);
            return;
        }
        RpcHeader.NegotiatePB response = parseSaslMsgResponse(msg);
        if (state == State.AWAIT_NEGOTIATE) {
            handleNegotiateResponse(ctx, response);
        } else if (state == State.AWAIT_SASL) {
            handleSaslMessage(ctx, response);
        } else if (state == State.AWAIT_AUTHN_TOKEN_EXCHANGE) {
            handleAuthnTokenExchangeResponse(ctx, response);
        } else if (state == State.AWAIT_JWT_EXCHANGE) {
            handleJwtExchangeResponse(ctx, response);
        } else if (state == State.AWAIT_TLS_HANDSHAKE) {
            handleTlsMessage(ctx, response);
        } else {
            throw new IllegalStateException("received a message in unexpected state: " + state);
        }
    }

    /** Ends negotiation after the server reported an error, forwarding it as a {@link Failure}. */
    private void handleErrorResponse(ChannelHandlerContext ctx, CallResponse msg) {
        RpcHeader.ErrorStatusPB.Builder errorBuilder = RpcHeader.ErrorStatusPB.newBuilder();
        KuduRpc.readProtobuf(msg.getPBMessage(), errorBuilder);
        RpcHeader.ErrorStatusPB error = errorBuilder.build();
        LOG.debug("peer {} sent connection negotiation error: {}", ctx.channel().remoteAddress(), error.getMessage());
        state = State.FINISHED;
        ctx.pipeline().remove(this);
        ctx.fireChannelRead(new Failure(error));
    }

    /**
     * Handles a message received while SASL authentication is in progress: either another
     * challenge or the final success message.
     *
     * @throws IllegalStateException for any other negotiation step
     */
    private void handleSaslMessage(ChannelHandlerContext ctx, RpcHeader.NegotiatePB response) throws IOException {
        RpcHeader.NegotiatePB.NegotiateStep step = response.getStep();
        if (step == RpcHeader.NegotiatePB.NegotiateStep.SASL_CHALLENGE) {
            handleChallengeResponse(ctx, response);
        } else if (step == RpcHeader.NegotiatePB.NegotiateStep.SASL_SUCCESS) {
            handleSuccessResponse(ctx, response);
        } else {
            throw new IllegalStateException("Wrong negotiation step: " + step);
        }
    }

    /**
     * Decodes the negotiation message carried by a response.
     *
     * @throws IllegalStateException if the response's call ID is not the SASL call ID
     */
    private RpcHeader.NegotiatePB parseSaslMsgResponse(CallResponse response) {
        if (response.getHeader().getCallId() != SASL_CALL_ID) {
            throw new IllegalStateException("Received a call that wasn't for SASL");
        }
        RpcHeader.NegotiatePB.Builder negotiate = RpcHeader.NegotiatePB.newBuilder();
        KuduRpc.readProtobuf(response.getPBMessage(), negotiate);
        return negotiate.build();
    }

    /**
     * Handles the server's NEGOTIATE response. Records the server's feature flags, enforces the
     * encryption requirement, and picks the authentication type (initializing a SASL mechanism
     * when SASL is chosen). It then starts the TLS handshake if the server supports TLS, or
     * starts authentication directly if not.
     *
     * @throws NonRecoverableException if encryption is required but the server lacks TLS
     * @throws IOException from SASL mechanism selection, TLS setup or starting authentication
     */
    private void handleNegotiateResponse(ChannelHandlerContext ctx, RpcHeader.NegotiatePB response) throws IOException {
        // Guava message templates use %s (not SLF4J-style {}) as the placeholder.
        Preconditions.checkState(response.getStep() == RpcHeader.NegotiatePB.NegotiateStep.NEGOTIATE, "Expected NEGOTIATE message, got %s", response.getStep());
        this.serverFeatures = this.getFeatureFlags(response);
        boolean negotiatedTls = this.serverFeatures.contains(RpcHeader.RpcFeatureFlag.TLS);
        if (!negotiatedTls && this.requireEncryption) {
            throw new NonRecoverableException(Status.NotAuthorized("server does not support required TLS encryption"));
        }
        this.chosenAuthnType = this.chooseAuthenticationType(response);
        if (this.chosenAuthnType == RpcHeader.AuthenticationTypePB.TypeCase.SASL) {
            this.chooseAndInitializeSaslMech(response);
        }
        if (negotiatedTls) {
            this.startTlsHandshake(ctx);
        } else {
            this.startAuthentication(ctx);
        }
    }

    /**
     * Reports whether the channel's local and remote IP addresses are equal, which this class
     * treats as a loopback connection.
     *
     * @param channel the connection being negotiated
     * @return true if {@code overrideLoopbackForTests} is set or both ends share an IP address;
     *     false for non-IP or missing/unresolved addresses
     */
    private boolean isLoopbackConnection(Channel channel) {
        if (this.overrideLoopbackForTests) {
            return true;
        }
        Object localAddr = channel.localAddress();
        Object remoteAddr = channel.remoteAddress();
        // Explicit type checks instead of catching ClassCastException. A null address
        // (e.g. on a channel that is not connected) is also treated as "not loopback".
        if (!(localAddr instanceof InetSocketAddress) || !(remoteAddr instanceof InetSocketAddress)) {
            return false;
        }
        InetAddress local = ((InetSocketAddress) localAddr).getAddress();
        InetAddress remote = ((InetSocketAddress) remoteAddr).getAddress();
        // getAddress() returns null for unresolved socket addresses.
        return local != null && local.equals(remote);
    }

    /**
     * Picks the SASL mechanism to use and creates {@code saslClient} for it.
     *
     * <p>Client mechanisms are tried in {@code SaslMechanism.values()} order (GSSAPI, then PLAIN).
     * A mechanism is skipped if the server did not advertise it, if it is GSSAPI and the subject
     * has no unexpired Kerberos ticket, or if creating the SaslClient fails. Each skip reason is
     * recorded so a useful error can be built when no mechanism works.
     *
     * @param response the server's NEGOTIATE response listing its SASL mechanisms
     * @throws NonRecoverableException if PLAIN was chosen but authentication is required, or if no
     *     mechanism could be used and no authentication token is held
     * @throws RecoverableException if no mechanism could be used while an authentication token is held
     */
    private void chooseAndInitializeSaslMech(RpcHeader.NegotiatePB response) throws KuduException {
        String message;
        // Refresh the subject before inspecting its Kerberos credentials below.
        this.securityContext.refreshSubject();
        // Mechanism name -> reason it was not chosen; used to build the failure message.
        HashMap<String, String> errorsByMech = Maps.newHashMap();
        HashSet<SaslMechanism> serverMechs = Sets.newHashSet();
        // Translate the server's advertised mechanism names (case-insensitively) into client enums.
        block10: for (RpcHeader.NegotiatePB.SaslMechanism mech : response.getSaslMechanismsList()) {
            switch (mech.getMechanism().toUpperCase(Locale.ENGLISH)) {
                case "GSSAPI": {
                    serverMechs.add(SaslMechanism.GSSAPI);
                    continue block10;
                }
                case "PLAIN": {
                    serverMechs.add(SaslMechanism.PLAIN);
                    continue block10;
                }
            }
            errorsByMech.put(mech.getMechanism(), "unrecognized mechanism");
        }
        // Try client mechanisms in declaration (preference) order; stop at the first that works.
        for (SaslMechanism clientMech : SaslMechanism.values()) {
            if (clientMech.equals((Object)SaslMechanism.GSSAPI)) {
                // GSSAPI needs a Kerberos ticket in the subject's private credentials.
                Subject s = this.securityContext.getSubject();
                if (s == null || s.getPrivateCredentials(KerberosTicket.class).isEmpty()) {
                    errorsByMech.put(clientMech.name(), "client does not have Kerberos credentials (tgt)");
                    continue;
                }
                if (SecurityUtil.isTgtExpired(s)) {
                    errorsByMech.put(clientMech.name(), "client Kerberos credentials (TGT) have expired");
                    continue;
                }
            }
            if (!serverMechs.contains((Object)clientMech)) {
                errorsByMech.put(clientMech.name(), "not advertised by server");
                continue;
            }
            HashMap<String, String> props = Maps.newHashMap();
            if (clientMech == SaslMechanism.GSSAPI) {
                // Request integrity protection. Presumably needed so saslClient.wrap()/unwrap() work
                // for the nonce and channel bindings handled later in negotiation.
                props.put("javax.security.sasl.qop", "auth-int");
            }
            try {
                this.saslClient = Sasl.createSaslClient(new String[]{clientMech.name()}, null, this.saslProtocolName, this.remoteHostname, props, this.saslCallback);
                this.chosenMech = clientMech;
                break;
            }
            catch (SaslException e) {
                errorsByMech.put(clientMech.name(), e.getMessage());
            }
        }
        if (this.chosenMech != null) {
            LOG.debug("SASL mechanism {} chosen for peer {}", (Object)this.chosenMech.name(), (Object)this.remoteHostname);
            // PLAIN does not authenticate the client, so refuse it when authentication is required.
            if (this.chosenMech.equals((Object)SaslMechanism.PLAIN) && this.requireAuthentication) {
                message = "client requires authentication, but server does not have Kerberos enabled";
                throw new NonRecoverableException(Status.NotAuthorized(message));
            }
            return;
        }
        // No mechanism worked. If the server only offers GSSAPI, report the GSSAPI-specific reason;
        // otherwise list every mechanism with its reason.
        message = serverMechs.size() == 1 && serverMechs.contains((Object)SaslMechanism.GSSAPI) ? "server requires authentication, but " + (String)errorsByMech.get(SaslMechanism.GSSAPI.name()) : "client/server supported SASL mechanism mismatch: [" + Joiner.on(", ").withKeyValueSeparator(": ").join(errorsByMech) + "]";
        if (this.authnTokenNotUsedReason != null) {
            message = message + ". Authentication tokens were not used because " + this.authnTokenNotUsedReason.msg;
        }
        // Holding a token makes the failure recoverable; without one it is not.
        if (this.authnToken != null) {
            throw new RecoverableException(Status.NotAuthorized(message));
        }
        throw new NonRecoverableException(Status.NotAuthorized(message));
    }

    /**
     * Validates the authentication type selected by the server in its NEGOTIATE response.
     *
     * @param response the server's NEGOTIATE response
     * @return the selected type; SASL when the server lists no authentication type
     * @throws IllegalArgumentException if the server lists more than one type, selects TOKEN or JWT
     *     when the client has none to offer, or selects an unrecognized type
     */
    private RpcHeader.AuthenticationTypePB.TypeCase chooseAuthenticationType(RpcHeader.NegotiatePB response) {
        Preconditions.checkArgument(response.getAuthnTypesCount() <= 1, "Expected server to reply with at most one authn type");
        if (response.getAuthnTypesCount() == 0) {
            return RpcHeader.AuthenticationTypePB.TypeCase.SASL;
        }
        RpcHeader.AuthenticationTypePB.TypeCase type = response.getAuthnTypes(0).getTypeCase();
        switch (type) {
            case SASL:
                if (this.authnToken != null) {
                    // A token was offered but the server preferred SASL; remember why for error messages.
                    this.authnTokenNotUsedReason = AuthnTokenNotUsedReason.NOT_CHOSEN_BY_SERVER;
                }
                break;
            case TOKEN:
                if (this.authnToken == null) {
                    throw new IllegalArgumentException("server chose token authentication but client had no valid token");
                }
                break;
            case JWT:
                if (this.jsonWebToken == null) {
                    throw new IllegalArgumentException("server chose JWT authentication but client had no valid JWT");
                }
                break;
            default:
                // Report the type actually received. chosenAuthnType is only assigned from this
                // method's return value, so it is still unset here.
                throw new IllegalArgumentException("server chose bad authn type " + type);
        }
        return type;
    }

    /**
     * Collects the feature flags advertised by the server, skipping UNKNOWN entries.
     *
     * @param response the server's NEGOTIATE response
     * @return an immutable set of the server's known feature flags
     */
    private Set<RpcHeader.RpcFeatureFlag> getFeatureFlags(RpcHeader.NegotiatePB response) {
        // Parameterized builder instead of the raw ImmutableSet.Builder type.
        ImmutableSet.Builder<RpcHeader.RpcFeatureFlag> features = ImmutableSet.builder();
        for (RpcHeader.RpcFeatureFlag feature : response.getSupportedFeaturesList()) {
            if (feature != RpcHeader.RpcFeatureFlag.UNKNOWN) {
                features.add(feature);
            }
        }
        return features.build();
    }

    /**
     * Creates a client-mode TLS engine limited to the preferred cipher suites and protocols,
     * runs it in an embedded channel, and sends its first handshake records to the server inside
     * a TLS_HANDSHAKE negotiation message.
     *
     * @throws SSLException if one is raised while creating the engine
     * @throws RuntimeException if no preferred cipher suite or protocol is supported by the engine
     */
    private void startTlsHandshake(ChannelHandlerContext ctx) throws SSLException {
        SSLEngine engine = this.createEngineForChosenAuthnType();
        engine.setUseClientMode(true);
        engine.setEnabledCipherSuites(Negotiator.selectPreferred(PREFERRED_CIPHER_SUITES, engine.getSupportedCipherSuites(), "found no preferred cipher suite among supported: "));
        engine.setEnabledProtocols(Negotiator.selectPreferred(PREFERRED_PROTOCOLS, engine.getSupportedProtocols(), "found no preferred TLS protocol among supported: "));
        SharableSslHandler handler = new SharableSslHandler(engine);
        this.sslEmbedder = new EmbeddedChannel(handler);
        this.sslHandshakeFuture = handler.handshakeFuture();
        this.state = State.AWAIT_TLS_HANDSHAKE;
        // The client speaks first in a TLS handshake, so there must be something to send.
        boolean sent = this.sendPendingOutboundTls(ctx);
        assert sent;
    }

    /**
     * Creates the TLS engine for the chosen authentication type. SASL uses a trust-all engine,
     * so the certificate is not verified here (GSSAPI later checks channel bindings against it).
     * TOKEN and JWT use the security context's regular engine, which presumably verifies the
     * server against the trusted certificates.
     */
    private SSLEngine createEngineForChosenAuthnType() throws SSLException {
        switch (this.chosenAuthnType) {
            case SASL:
                return this.securityContext.createSSLEngineTrustAll();
            case TOKEN:
            case JWT:
                return this.securityContext.createSSLEngine();
            default:
                throw new AssertionError((Object) "unreachable");
        }
    }

    /**
     * Returns the entries of {@code preferred}, in preference order, that appear in {@code available}.
     *
     * @throws RuntimeException if none are available; the message is {@code errorPrefix}
     *     followed by the comma-joined available entries
     */
    private static String[] selectPreferred(String[] preferred, String[] available, String errorPrefix) {
        HashSet<String> availableSet = Sets.newHashSet(available);
        ArrayList<String> chosen = Lists.newArrayList();
        for (String candidate : preferred) {
            if (availableSet.contains(candidate)) {
                chosen.add(candidate);
            }
        }
        if (chosen.isEmpty()) {
            throw new RuntimeException(errorPrefix + Joiner.on(',').join(availableSet));
        }
        return chosen.toArray(new String[0]);
    }

    /**
     * Feeds a TLS_HANDSHAKE message from the server into the embedded TLS engine. While the
     * handshake is in progress, the engine's reply is sent back. Once it completes, this method
     * captures the server certificate and moves the TLS handler into the connection's pipeline,
     * unless TLS is only used for authentication on a loopback connection. It then starts
     * authentication.
     *
     * @throws SSLException if the handshake completed unsuccessfully
     * @throws SSLPeerUnverifiedException if the server presented no certificate
     * @throws IOException from starting authentication
     */
    private void handleTlsMessage(ChannelHandlerContext ctx, RpcHeader.NegotiatePB response) throws IOException {
        Preconditions.checkState(response.getStep() == RpcHeader.NegotiatePB.NegotiateStep.TLS_HANDSHAKE);
        Preconditions.checkArgument(!response.getTlsHandshake().isEmpty(), "empty TLS message from server");
        this.sslEmbedder.writeInbound(Unpooled.copiedBuffer(response.getTlsHandshake().asReadOnlyByteBuffer()));
        this.sslEmbedder.flush();
        if (this.sendPendingOutboundTls(ctx)) {
            // Handshake still in progress; wait for the server's next TLS message.
            return;
        }
        // The handshake future is done. Fail clearly, with the cause, if it did not succeed,
        // rather than failing later while reading the session's peer certificates.
        if (!this.sslHandshakeFuture.isSuccess()) {
            throw new SSLException("TLS handshake with server failed", this.sslHandshakeFuture.cause());
        }
        SharableSslHandler handler = (SharableSslHandler) this.sslEmbedder.pipeline().first();
        // Presumably needed so Netty accepts this already-added handler into a second pipeline below.
        handler.resetAdded();
        Certificate[] certs = handler.engine().getSession().getPeerCertificates();
        if (certs.length == 0) {
            throw new SSLPeerUnverifiedException("no peer cert found");
        }
        this.peerCert = certs[0];
        // With auth-only TLS on loopback, the handler is not added, so traffic is not encrypted.
        boolean isAuthOnly = this.serverFeatures.contains(RpcHeader.RpcFeatureFlag.TLS_AUTHENTICATION_ONLY) && this.isLoopbackConnection(ctx.channel()) && !this.encryptLoopback;
        if (!isAuthOnly) {
            ctx.pipeline().addFirst("tls", (ChannelHandler) handler);
        }
        this.startAuthentication(ctx);
    }

    /**
     * Drains the TLS records produced by the embedded engine and sends them to the server as
     * one TLS_HANDSHAKE negotiation message.
     *
     * @return true if the handshake is still in progress (data was sent); false once the handshake
     *     future is done, in which case any remaining records are still sent
     */
    private boolean sendPendingOutboundTls(ChannelHandlerContext ctx) {
        ArrayList<ByteString> chunks = Lists.newArrayList();
        while (!this.sslEmbedder.outboundMessages().isEmpty()) {
            ByteBuf msg = (ByteBuf) this.sslEmbedder.readOutbound();
            try {
                chunks.add(ByteString.copyFrom(msg.nioBuffer()));
            } finally {
                // Release even if copying fails so the buffer is not leaked.
                msg.release();
            }
        }
        ByteString data = ByteString.copyFrom(chunks);
        if (this.sslHandshakeFuture.isDone()) {
            if (!data.isEmpty()) {
                this.sendTunneledTls(ctx, data);
            }
            return false;
        }
        // A handshake that is not finished must have produced something to send.
        assert data.size() > 0;
        this.sendTunneledTls(ctx, data);
        return true;
    }

    /** Sends raw TLS bytes to the server wrapped in a TLS_HANDSHAKE negotiation message. */
    private void sendTunneledTls(ChannelHandlerContext ctx, ByteString buf) {
        RpcHeader.NegotiatePB tunneled = RpcHeader.NegotiatePB.newBuilder()
                .setStep(RpcHeader.NegotiatePB.NegotiateStep.TLS_HANDSHAKE)
                .setTlsHandshake(buf)
                .build();
        this.sendSaslMessage(ctx, tunneled);
    }

    /**
     * Sends the first authentication message for the chosen authentication type.
     *
     * @throws SaslException if the SASL client fails to produce its initial response
     * @throws NonRecoverableException if SASL reports missing Kerberos credentials
     */
    private void startAuthentication(ChannelHandlerContext ctx) throws SaslException, NonRecoverableException {
        switch (this.chosenAuthnType) {
            case SASL:
                this.sendSaslInitiate(ctx);
                return;
            case TOKEN:
                this.sendTokenExchange(ctx);
                return;
            case JWT:
                this.sendJwtExchange(ctx);
                return;
            default:
                throw new AssertionError((Object) "unreachable");
        }
    }

    /**
     * Sends the authentication token to the server and waits for its TOKEN_EXCHANGE reply.
     * The token is only sent after a successful TLS handshake.
     */
    private void sendTokenExchange(ChannelHandlerContext ctx) {
        Preconditions.checkNotNull(this.authnToken);
        Preconditions.checkNotNull(this.sslHandshakeFuture);
        Preconditions.checkState(this.sslHandshakeFuture.isSuccess());
        RpcHeader.NegotiatePB.Builder exchange = RpcHeader.NegotiatePB.newBuilder()
                .setStep(RpcHeader.NegotiatePB.NegotiateStep.TOKEN_EXCHANGE)
                .setAuthnToken(this.authnToken);
        this.state = State.AWAIT_AUTHN_TOKEN_EXCHANGE;
        this.sendSaslMessage(ctx, exchange.build());
    }

    /**
     * Sends the JSON Web Token to the server and waits for its JWT_EXCHANGE reply.
     * The JWT is only sent after a successful TLS handshake.
     */
    private void sendJwtExchange(ChannelHandlerContext ctx) {
        Preconditions.checkNotNull(this.jsonWebToken);
        Preconditions.checkNotNull(this.sslHandshakeFuture);
        Preconditions.checkState(this.sslHandshakeFuture.isSuccess());
        RpcHeader.NegotiatePB.Builder exchange = RpcHeader.NegotiatePB.newBuilder()
                .setStep(RpcHeader.NegotiatePB.NegotiateStep.JWT_EXCHANGE)
                .setJwtRaw(this.jsonWebToken);
        this.state = State.AWAIT_JWT_EXCHANGE;
        this.sendSaslMessage(ctx, exchange.build());
    }

    /**
     * Handles the server's reply to the token exchange and finishes negotiation.
     *
     * @throws IllegalArgumentException if the reply is not a TOKEN_EXCHANGE step
     * @throws SaslException from building the connection context
     */
    private void handleAuthnTokenExchangeResponse(ChannelHandlerContext ctx, RpcHeader.NegotiatePB response) throws SaslException {
        // Guava message templates use %s (not {}) as the placeholder.
        Preconditions.checkArgument(response.getStep() == RpcHeader.NegotiatePB.NegotiateStep.TOKEN_EXCHANGE, "expected TOKEN_EXCHANGE, got step: %s", response.getStep());
        this.finish(ctx);
    }

    /**
     * Handles the server's reply to the JWT exchange and finishes negotiation.
     *
     * @throws IllegalArgumentException if the reply is not a JWT_EXCHANGE step
     * @throws SaslException from building the connection context
     */
    private void handleJwtExchangeResponse(ChannelHandlerContext ctx, RpcHeader.NegotiatePB response) throws SaslException {
        // Guava message templates use %s (not {}) as the placeholder.
        Preconditions.checkArgument(response.getStep() == RpcHeader.NegotiatePB.NegotiateStep.JWT_EXCHANGE, "expected JWT_EXCHANGE, got step: %s", response.getStep());
        this.finish(ctx);
    }

    /**
     * Sends SASL_INITIATE for the chosen mechanism, including the SASL client's initial response
     * when the mechanism has one. Then waits for SASL messages.
     */
    private void sendSaslInitiate(ChannelHandlerContext ctx) throws SaslException, NonRecoverableException {
        RpcHeader.NegotiatePB.Builder initiate = RpcHeader.NegotiatePB.newBuilder();
        if (this.saslClient.hasInitialResponse()) {
            // The initial response is produced by evaluating an empty challenge.
            initiate.setToken(UnsafeByteOperations.unsafeWrap(this.evaluateChallenge(new byte[0])));
        }
        initiate.setStep(RpcHeader.NegotiatePB.NegotiateStep.SASL_INITIATE);
        initiate.addSaslMechanismsBuilder().setMechanism(this.chosenMech.name());
        this.state = State.AWAIT_SASL;
        this.sendSaslMessage(ctx, initiate.build());
    }

    /**
     * Answers a SASL_CHALLENGE from the server with a SASL_RESPONSE carrying the SASL client's reply.
     *
     * @throws IllegalStateException if the SASL client produced no reply
     */
    private void handleChallengeResponse(ChannelHandlerContext ctx, RpcHeader.NegotiatePB response) throws SaslException, NonRecoverableException {
        byte[] reply = this.evaluateChallenge(response.getToken().toByteArray());
        if (reply == null) {
            throw new IllegalStateException("Not expecting an empty token");
        }
        this.sendSaslMessage(ctx, RpcHeader.NegotiatePB.newBuilder()
                .setToken(UnsafeByteOperations.unsafeWrap(reply))
                .setStep(RpcHeader.NegotiatePB.NegotiateStep.SASL_RESPONSE)
                .build());
    }

    /**
     * Checks that the channel bindings sent by the server match those derived from the server's
     * TLS certificate.
     *
     * @throws SSLPeerUnverifiedException if bindings are missing, too short, or do not match
     */
    private void verifyChannelBindings(RpcHeader.NegotiatePB response) throws IOException {
        byte[] expected = SecurityUtil.getEndpointChannelBindings(this.peerCert);
        if (!response.hasChannelBindings()) {
            throw new SSLPeerUnverifiedException("no channel bindings provided by remote peer");
        }
        byte[] provided = response.getChannelBindings().toByteArray();
        // Leading 4 bytes are skipped. Presumably a length prefix, mirroring the framing of the
        // encoded nonce in makeConnectionContext(); its value is not checked here.
        final int prefixLength = 4;
        if (provided.length < prefixLength) {
            throw new SSLPeerUnverifiedException("invalid too-short channel bindings");
        }
        byte[] unwrapped = this.saslClient.unwrap(provided, prefixLength, provided.length - prefixLength);
        if (!Bytes.equals(expected, unwrapped)) {
            throw new SSLPeerUnverifiedException("invalid channel bindings provided by remote peer");
        }
    }

    /**
     * Handles SASL_SUCCESS. For GSSAPI, stores the server's nonce (if any) and verifies channel
     * bindings when a TLS certificate was obtained. Then finishes negotiation.
     *
     * @throws IllegalStateException if the SASL client has not completed
     */
    private void handleSuccessResponse(ChannelHandlerContext ctx, RpcHeader.NegotiatePB response) throws IOException {
        Preconditions.checkState(this.saslClient.isComplete(), "server sent SASL_SUCCESS step, but SASL negotiation is not complete");
        boolean usingGssapi = this.chosenMech == SaslMechanism.GSSAPI;
        if (usingGssapi && response.hasNonce()) {
            this.nonce = response.getNonce().toByteArray();
        }
        if (usingGssapi && this.peerCert != null) {
            this.verifyChannelBindings(response);
        }
        this.finish(ctx);
    }

    /**
     * Completes negotiation. Marks the state finished, removes this handler from the pipeline,
     * sends the connection context, and fires a {@link Success} carrying the server's feature
     * flags to the next handler.
     *
     * @throws SaslException if wrapping the nonce for the connection context fails
     */
    private void finish(ChannelHandlerContext ctx) throws SaslException {
        this.state = State.FINISHED;
        ctx.pipeline().remove(this);
        RpcOutboundMessage connectionContext = this.makeConnectionContext();
        ctx.writeAndFlush(connectionContext, ctx.voidPromise());
        LOG.debug("Authenticated connection {} using {}/{}", ctx.channel(), this.chosenAuthnType, this.chosenMech);
        ctx.fireChannelRead(new Success(this.serverFeatures));
    }

    /**
     * Builds the ConnectionContextPB sent after negotiation. It carries the real user as both the
     * effective and the real user. If the server sent a nonce, it also carries that nonce wrapped
     * by the SASL client and prefixed with the wrapped length as a 4-byte big-endian integer.
     *
     * @throws SaslException if the SASL client fails to wrap the nonce
     */
    private RpcOutboundMessage makeConnectionContext() throws SaslException {
        String user = this.securityContext.getRealUser();
        RpcHeader.UserInformationPB userInfo = RpcHeader.UserInformationPB.newBuilder()
                .setEffectiveUser(user)
                .setRealUser(user)
                .build();
        RpcHeader.ConnectionContextPB.Builder context = RpcHeader.ConnectionContextPB.newBuilder()
                .setDEPRECATEDUserInfo(userInfo);
        if (this.nonce != null) {
            byte[] wrapped = this.saslClient.wrap(this.nonce, 0, this.nonce.length);
            ByteBuffer framed = ByteBuffer.allocate(4 + wrapped.length).order(ByteOrder.BIG_ENDIAN);
            framed.putInt(wrapped.length).put(wrapped);
            context.setEncodedNonce(UnsafeByteOperations.unsafeWrap(framed.array()));
        }
        RpcHeader.ConnectionContextPB contextPb = context.build();
        RpcHeader.RequestHeader.Builder header = RpcHeader.RequestHeader.newBuilder().setCallId(CONNECTION_CTX_CALL_ID);
        return new RpcOutboundMessage(header, contextPb);
    }

    /**
     * Runs {@code saslClient.evaluateChallenge(challenge)} as the security context's subject.
     *
     * @param challenge the server's challenge (an empty array when computing the initial response)
     * @return the reply produced by the SASL client
     * @throws NonRecoverableException if the underlying GSS failure is NO_CRED, i.e. the server
     *     requires Kerberos but the client has no usable TGT
     * @throws SaslException for any other SASL failure
     */
    private byte[] evaluateChallenge(final byte[] challenge) throws SaslException, NonRecoverableException {
        try {
            return SecurityManagerCompatibility.get().callAs(this.securityContext.getSubject(), new Callable<byte[]>(){

                @Override
                public byte[] call() throws SaslException {
                    return Negotiator.this.saslClient.evaluateChallenge(challenge);
                }
            });
        }
        catch (RuntimeException wrapper) {
            Throwable privileged = wrapper.getCause();
            if (!(privileged instanceof PrivilegedActionException)) {
                throw wrapper;
            }
            // Assumes callAs wraps only checked exceptions from call(), and call() can only throw
            // SaslException, so the cast holds.
            SaslException saslException = (SaslException) privileged.getCause();
            Throwable gssCause = saslException.getCause();
            boolean missingCredentials = gssCause instanceof GSSException && ((GSSException) gssCause).getMajor() == GSSException.NO_CRED;
            if (missingCredentials) {
                throw new NonRecoverableException(Status.ConfigurationError("Server requires Kerberos, but this client is not authenticated (missing or expired TGT)"), (Throwable) saslException);
            }
            throw saslException;
        }
    }

    /**
     * An {@link SslHandler} that first runs the handshake inside the negotiator's embedded
     * channel and is then added to the connection's own pipeline, after {@link #resetAdded()}.
     */
    static class SharableSslHandler
    extends SslHandler {
        public SharableSslHandler(SSLEngine engine) {
            super(engine);
        }

        /**
         * Sets the private {@code added} field inherited from {@link ChannelHandlerAdapter} back to
         * false, using reflection.
         * NOTE(review): presumably this lets Netty accept the handler in a second pipeline even
         * though it is not sharable. It depends on Netty's private field name and breaks if that
         * field is renamed or removed.
         *
         * @throws RuntimeException if the field cannot be found or written
         */
        void resetAdded() {
            Field addedField = SecurityManagerCompatibility.get().doPrivileged(() -> {
                try {
                    Class<ChannelHandlerAdapter> c = ChannelHandlerAdapter.class;
                    Field added = c.getDeclaredField("added");
                    added.setAccessible(true);
                    return added;
                }
                catch (NoSuchFieldException e) {
                    throw new RuntimeException(e);
                }
            });
            try {
                addedField.setBoolean(this, false);
            }
            catch (IllegalAccessException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /** Fired to the next handler when the server rejects negotiation; carries the server's error. */
    static class Failure {
        final RpcHeader.ErrorStatusPB status;

        public Failure(RpcHeader.ErrorStatusPB errorStatus) {
            status = errorStatus;
        }
    }

    /** Fired to the next handler when negotiation completes; carries the server's feature flags. */
    static class Success {
        final Set<RpcHeader.RpcFeatureFlag> serverFeatures;

        public Success(Set<RpcHeader.RpcFeatureFlag> features) {
            serverFeatures = features;
        }
    }

    /**
     * Answers SASL client callbacks. The name is the security context's real user and the password
     * is empty. Any other callback type is rejected.
     */
    private class SaslClientCallbackHandler
    implements CallbackHandler {
        private SaslClientCallbackHandler() {
        }

        @Override
        public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
            for (int i = 0; i < callbacks.length; i++) {
                Callback cb = callbacks[i];
                if (cb instanceof NameCallback) {
                    ((NameCallback) cb).setName(Negotiator.this.securityContext.getRealUser());
                } else if (cb instanceof PasswordCallback) {
                    ((PasswordCallback) cb).setPassword(new char[0]);
                } else {
                    throw new UnsupportedCallbackException(cb, "Unrecognized SASL client callback");
                }
            }
        }
    }

    /** Why the authentication token is not used; its {@code msg} is appended to SASL failure messages. */
    private static enum AuthnTokenNotUsedReason {
        NONE_AVAILABLE("no token is available"),
        NO_TRUSTED_CERTS("no TLS certificates are trusted by the client"),
        FORBIDDEN_BY_POLICY("this connection will be used to acquire a new token and therefore requires primary credentials"),
        NOT_CHOSEN_BY_SERVER("the server chose not to accept token authentication");

        final String msg;

        AuthnTokenNotUsedReason(String description) {
            this.msg = description;
        }
    }

    /** States of the client-side negotiation state machine. */
    private static enum State {
        INITIAL, AWAIT_NEGOTIATE, AWAIT_TLS_HANDSHAKE, AWAIT_AUTHN_TOKEN_EXCHANGE, AWAIT_JWT_EXCHANGE, AWAIT_SASL, FINISHED
    }

    /**
     * SASL mechanisms the client supports. Declaration order is the client's preference order,
     * because mechanisms are tried in {@code values()} order.
     */
    private static enum SaslMechanism {
        GSSAPI, PLAIN
    }
}

