#include "extensions/transport_sockets/proxy_protocol/proxy_protocol.h" #include #include "envoy/config/core/v3/proxy_protocol.pb.h" #include "envoy/network/transport_socket.h" #include "common/buffer/buffer_impl.h" #include "common/network/address_impl.h" #include "extensions/common/proxy_protocol/proxy_protocol_header.h" using envoy::config::core::v3::ProxyProtocolConfig; using envoy::config::core::v3::ProxyProtocolConfig_Version; namespace Envoy { namespace Extensions { namespace TransportSockets { namespace ProxyProtocol { UpstreamProxyProtocolSocket::UpstreamProxyProtocolSocket( Network::TransportSocketPtr&& transport_socket, Network::TransportSocketOptionsSharedPtr options, ProxyProtocolConfig_Version version) : PassthroughSocket(std::move(transport_socket)), options_(options), version_(version) {} void UpstreamProxyProtocolSocket::setTransportSocketCallbacks( Network::TransportSocketCallbacks& callbacks) { transport_socket_->setTransportSocketCallbacks(callbacks); callbacks_ = &callbacks; } Network::IoResult UpstreamProxyProtocolSocket::doWrite(Buffer::Instance& buffer, bool end_stream) { if (header_buffer_.length() > 0) { auto header_res = writeHeader(); if (header_buffer_.length() == 0 && header_res.action_ == Network::PostIoAction::KeepOpen) { auto inner_res = transport_socket_->doWrite(buffer, end_stream); return {inner_res.action_, header_res.bytes_processed_ + inner_res.bytes_processed_, false}; } return header_res; } else { return transport_socket_->doWrite(buffer, end_stream); } } void UpstreamProxyProtocolSocket::generateHeader() { if (version_ == ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1) { generateHeaderV1(); } else { generateHeaderV2(); } } void UpstreamProxyProtocolSocket::generateHeaderV1() { // Default to local addresses (used if no downstream connection exists e.g. health checks) auto src_addr = callbacks_->connection().localAddress(); auto dst_addr = callbacks_->connection().remoteAddress(); if (options_ && options_->proxyProtocolOptions().has_value()) { const auto options = options_->proxyProtocolOptions().value(); src_addr = options.src_addr_; dst_addr = options.dst_addr_; } Common::ProxyProtocol::generateV1Header(*src_addr->ip(), *dst_addr->ip(), header_buffer_); } void UpstreamProxyProtocolSocket::generateHeaderV2() { if (!options_ || !options_->proxyProtocolOptions().has_value()) { Common::ProxyProtocol::generateV2LocalHeader(header_buffer_); } else { const auto options = options_->proxyProtocolOptions().value(); Common::ProxyProtocol::generateV2Header(*options.src_addr_->ip(), *options.dst_addr_->ip(), header_buffer_); } } Network::IoResult UpstreamProxyProtocolSocket::writeHeader() { Network::PostIoAction action = Network::PostIoAction::KeepOpen; uint64_t bytes_written = 0; do { if (header_buffer_.length() == 0) { break; } Api::IoCallUint64Result result = callbacks_->ioHandle().write(header_buffer_); if (result.ok()) { ENVOY_CONN_LOG(trace, "write returns: {}", callbacks_->connection(), result.rc_); bytes_written += result.rc_; } else { ENVOY_CONN_LOG(trace, "write error: {}", callbacks_->connection(), result.err_->getErrorDetails()); if (result.err_->getErrorCode() != Api::IoError::IoErrorCode::Again) { action = Network::PostIoAction::Close; } break; } } while (true); return {action, bytes_written, false}; } void UpstreamProxyProtocolSocket::onConnected() { generateHeader(); transport_socket_->onConnected(); } UpstreamProxyProtocolSocketFactory::UpstreamProxyProtocolSocketFactory( Network::TransportSocketFactoryPtr transport_socket_factory, ProxyProtocolConfig config) : transport_socket_factory_(std::move(transport_socket_factory)), config_(config) {} Network::TransportSocketPtr UpstreamProxyProtocolSocketFactory::createTransportSocket( Network::TransportSocketOptionsSharedPtr options) const { auto inner_socket = transport_socket_factory_->createTransportSocket(options); if (inner_socket == nullptr) { return nullptr; } return std::make_unique(std::move(inner_socket), options, config_.version()); } bool UpstreamProxyProtocolSocketFactory::implementsSecureTransport() const { return transport_socket_factory_->implementsSecureTransport(); } } // namespace ProxyProtocol } // namespace TransportSockets } // namespace Extensions } // namespace Envoy