From 4e30e747390ac6bb6a6e5ededfbcdae937579393 Mon Sep 17 00:00:00 2001 From: Max Kellermann Date: Tue, 10 Feb 2015 20:30:10 +0100 Subject: net/SocketAddress: light wrapper for struct sockaddr --- src/Listen.cxx | 6 +- src/client/Client.hxx | 4 +- src/client/ClientNew.cxx | 7 +- src/event/ServerSocket.cxx | 43 +++++------ src/event/ServerSocket.hxx | 7 +- src/net/Resolver.cxx | 21 ++--- src/net/Resolver.hxx | 4 +- src/net/SocketAddress.cxx | 38 +++++++++ src/net/SocketAddress.hxx | 103 +++++++++++++++++++++++++ src/net/SocketUtil.cxx | 12 +-- src/net/SocketUtil.hxx | 6 +- src/output/plugins/httpd/HttpdInternal.hxx | 3 +- src/output/plugins/httpd/HttpdOutputPlugin.cxx | 10 +-- 13 files changed, 198 insertions(+), 66 deletions(-) create mode 100644 src/net/SocketAddress.cxx create mode 100644 src/net/SocketAddress.hxx (limited to 'src') diff --git a/src/Listen.cxx b/src/Listen.cxx index 3ba368a0c..cf4b41352 100644 --- a/src/Listen.cxx +++ b/src/Listen.cxx @@ -23,6 +23,7 @@ #include "config/Param.hxx" #include "config/ConfigGlobal.hxx" #include "config/ConfigOption.hxx" +#include "net/SocketAddress.hxx" #include "event/ServerSocket.hxx" #include "util/Error.hxx" #include "util/Domain.hxx" @@ -48,10 +49,9 @@ public: :ServerSocket(_loop), partition(_partition) {} private: - void OnAccept(int fd, const sockaddr &address, - size_t address_length, int uid) override { + void OnAccept(int fd, SocketAddress address, int uid) override { client_new(GetEventLoop(), partition, - fd, &address, address_length, uid); + fd, address, uid); } }; diff --git a/src/client/Client.hxx b/src/client/Client.hxx index ef7f4c406..bdfb1ef93 100644 --- a/src/client/Client.hxx +++ b/src/client/Client.hxx @@ -36,7 +36,7 @@ #include #include -struct sockaddr; +class SocketAddress; class EventLoop; class Path; struct Partition; @@ -204,7 +204,7 @@ void client_manager_init(void); void client_new(EventLoop &loop, Partition &partition, - int fd, const sockaddr *sa, size_t sa_length, int uid); + int fd, SocketAddress address, int uid); /** * Write a C string to the client. diff --git a/src/client/ClientNew.cxx b/src/client/ClientNew.cxx index 2b6941fd0..8dca9f81f 100644 --- a/src/client/ClientNew.cxx +++ b/src/client/ClientNew.cxx @@ -23,6 +23,7 @@ #include "Partition.hxx" #include "Instance.hxx" #include "system/fd_util.h" +#include "net/SocketAddress.hxx" #include "net/Resolver.hxx" #include "Permission.hxx" #include "util/Error.hxx" @@ -58,15 +59,15 @@ Client::Client(EventLoop &_loop, Partition &_partition, void client_new(EventLoop &loop, Partition &partition, - int fd, const struct sockaddr *sa, size_t sa_length, int uid) + int fd, SocketAddress address, int uid) { static unsigned int next_client_num; - const auto remote = sockaddr_to_string(sa, sa_length); + const auto remote = sockaddr_to_string(address); assert(fd >= 0); #ifdef HAVE_LIBWRAP - if (sa->sa_family != AF_UNIX) { + if (address.GetFamily() != AF_UNIX) { // TODO: shall we obtain the program name from argv[0]? const char *progname = "mpd"; diff --git a/src/event/ServerSocket.cxx b/src/event/ServerSocket.cxx index daf8517fc..d7a429f62 100644 --- a/src/event/ServerSocket.cxx +++ b/src/event/ServerSocket.cxx @@ -19,6 +19,7 @@ #include "config.h" #include "ServerSocket.hxx" +#include "net/SocketAddress.hxx" #include "net/SocketUtil.hxx" #include "net/SocketError.hxx" #include "net/Resolver.hxx" @@ -59,29 +60,28 @@ class OneServerSocket final : private SocketMonitor { AllocatedPath path; - size_t address_length; - struct sockaddr *address; + SocketAddress address; public: OneServerSocket(EventLoop &_loop, ServerSocket &_parent, unsigned _serial, - const struct sockaddr *_address, - size_t _address_length) + SocketAddress _address) :SocketMonitor(_loop), parent(_parent), serial(_serial), path(AllocatedPath::Null()), - address_length(_address_length), - address((sockaddr *)xmemdup(_address, _address_length)) + address((sockaddr *)xmemdup(_address.GetAddress(), + _address.GetSize()), + _address.GetSize()) { - assert(_address != nullptr); - assert(_address_length > 0); + assert(!_address.IsNull()); + assert(_address.GetSize() > 0); } OneServerSocket(const OneServerSocket &other) = delete; OneServerSocket &operator=(const OneServerSocket &other) = delete; ~OneServerSocket() { - free(address); + free(const_cast(address.GetAddress())); if (IsDefined()) Close(); @@ -104,7 +104,7 @@ public: gcc_pure std::string ToString() const { - return sockaddr_to_string(address, address_length); + return sockaddr_to_string(address); } void SetFD(int _fd) { @@ -168,8 +168,8 @@ OneServerSocket::Accept() } parent.OnAccept(peer_fd, - (const sockaddr &)peer_address, - peer_address_length, get_remote_uid(peer_fd)); + { (const sockaddr *)&peer_address, socklen_t(peer_address_length) }, + get_remote_uid(peer_fd)); } bool @@ -184,9 +184,9 @@ OneServerSocket::Open(Error &error) { assert(!IsDefined()); - int _fd = socket_bind_listen(address->sa_family, + int _fd = socket_bind_listen(address.GetFamily(), SOCK_STREAM, 0, - address, address_length, 5, + address, 5, error); if (_fd < 0) return false; @@ -280,10 +280,10 @@ ServerSocket::Close() } OneServerSocket & -ServerSocket::AddAddress(const sockaddr &address, size_t address_length) +ServerSocket::AddAddress(SocketAddress address) { sockets.emplace_back(loop, *this, next_serial, - &address, address_length); + address); return sockets.back(); } @@ -302,8 +302,7 @@ ServerSocket::AddFD(int fd, Error &error) return false; } - OneServerSocket &s = AddAddress((const sockaddr &)address, - address_length); + OneServerSocket &s = AddAddress({(const sockaddr *)&address, address_length}); s.SetFD(fd); return true; @@ -320,7 +319,7 @@ ServerSocket::AddPortIPv4(unsigned port) sin.sin_family = AF_INET; sin.sin_addr.s_addr = INADDR_ANY; - AddAddress((const sockaddr &)sin, sizeof(sin)); + AddAddress({(const sockaddr *)&sin, sizeof(sin)}); } #ifdef HAVE_IPV6 @@ -333,7 +332,7 @@ ServerSocket::AddPortIPv6(unsigned port) sin.sin6_port = htons(port); sin.sin6_family = AF_INET6; - AddAddress((const sockaddr &)sin, sizeof(sin)); + AddAddress({(const sockaddr *)&sin, sizeof(sin)}); } /** @@ -392,7 +391,7 @@ ServerSocket::AddHost(const char *hostname, unsigned port, Error &error) return false; for (const struct addrinfo *i = ai; i != nullptr; i = i->ai_next) - AddAddress(*i->ai_addr, i->ai_addrlen); + AddAddress(SocketAddress(i->ai_addr, i->ai_addrlen)); freeaddrinfo(ai); @@ -426,7 +425,7 @@ ServerSocket::AddPath(AllocatedPath &&path, Error &error) s_un.sun_family = AF_UNIX; memcpy(s_un.sun_path, path.c_str(), path_length + 1); - OneServerSocket &s = AddAddress((const sockaddr &)s_un, sizeof(s_un)); + OneServerSocket &s = AddAddress({(const sockaddr *)&s_un, sizeof(s_un)}); s.SetPath(std::move(path)); return true; diff --git a/src/event/ServerSocket.hxx b/src/event/ServerSocket.hxx index 314889517..e5b7cffad 100644 --- a/src/event/ServerSocket.hxx +++ b/src/event/ServerSocket.hxx @@ -24,7 +24,7 @@ #include -struct sockaddr; +class SocketAddress; class EventLoop; class Error; class AllocatedPath; @@ -51,7 +51,7 @@ public: } private: - OneServerSocket &AddAddress(const sockaddr &address, size_t length); + OneServerSocket &AddAddress(SocketAddress address); /** * Add a listener on a port on all IPv4 interfaces. @@ -111,8 +111,7 @@ public: void Close(); protected: - virtual void OnAccept(int fd, const sockaddr &address, - size_t address_length, int uid) = 0; + virtual void OnAccept(int fd, SocketAddress address, int uid) = 0; }; #endif diff --git a/src/net/Resolver.cxx b/src/net/Resolver.cxx index 389b3d4b5..cfbce5ff6 100644 --- a/src/net/Resolver.cxx +++ b/src/net/Resolver.cxx @@ -19,6 +19,7 @@ #include "config.h" #include "Resolver.hxx" +#include "SocketAddress.hxx" #include "util/Error.hxx" #include "util/Domain.hxx" @@ -43,13 +44,14 @@ const Domain resolver_domain("resolver"); std::string -sockaddr_to_string(const struct sockaddr *sa, size_t length) +sockaddr_to_string(SocketAddress address) { #ifdef HAVE_UN - if (sa->sa_family == AF_UNIX) { + if (address.GetFamily() == AF_UNIX) { /* return path of UNIX domain sockets */ - const sockaddr_un &s_un = *(const sockaddr_un *)sa; - if (length < sizeof(s_un) || s_un.sun_path[0] == 0) + const sockaddr_un &s_un = *(const sockaddr_un *) + address.GetAddress(); + if (address.GetSize() < sizeof(s_un) || s_un.sun_path[0] == 0) return "local"; return s_un.sun_path; @@ -57,14 +59,15 @@ sockaddr_to_string(const struct sockaddr *sa, size_t length) #endif #if defined(HAVE_IPV6) && defined(IN6_IS_ADDR_V4MAPPED) - const struct sockaddr_in6 *a6 = (const struct sockaddr_in6 *)sa; + const struct sockaddr_in6 *a6 = (const struct sockaddr_in6 *) + address.GetAddress(); struct sockaddr_in a4; #endif int ret; char host[NI_MAXHOST], serv[NI_MAXSERV]; #if defined(HAVE_IPV6) && defined(IN6_IS_ADDR_V4MAPPED) - if (sa->sa_family == AF_INET6 && + if (address.GetFamily() == AF_INET6 && IN6_IS_ADDR_V4MAPPED(&a6->sin6_addr)) { /* convert "::ffff:127.0.0.1" to "127.0.0.1" */ @@ -74,12 +77,12 @@ sockaddr_to_string(const struct sockaddr *sa, size_t length) sizeof(a4.sin_addr)); a4.sin_port = a6->sin6_port; - sa = (const struct sockaddr *)&a4; - length = sizeof(a4); + address = { (const struct sockaddr *)&a4, sizeof(a4) }; } #endif - ret = getnameinfo(sa, length, host, sizeof(host), serv, sizeof(serv), + ret = getnameinfo(address.GetAddress(), address.GetSize(), + host, sizeof(host), serv, sizeof(serv), NI_NUMERICHOST|NI_NUMERICSERV); if (ret != 0) return "unknown"; diff --git a/src/net/Resolver.hxx b/src/net/Resolver.hxx index a9596c299..7567983f0 100644 --- a/src/net/Resolver.hxx +++ b/src/net/Resolver.hxx @@ -26,10 +26,10 @@ #include -struct sockaddr; struct addrinfo; class Error; class Domain; +class SocketAddress; extern const Domain resolver_domain; @@ -42,7 +42,7 @@ extern const Domain resolver_domain; */ gcc_pure std::string -sockaddr_to_string(const sockaddr *sa, size_t length); +sockaddr_to_string(SocketAddress address); /** * Resolve a specification in the form "host", "host:port", diff --git a/src/net/SocketAddress.cxx b/src/net/SocketAddress.cxx new file mode 100644 index 000000000..38aeb8d6d --- /dev/null +++ b/src/net/SocketAddress.cxx @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2012-2015 Max Kellermann + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * - Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * - Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE + * FOUNDATION OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, + * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED + * OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#include "SocketAddress.hxx" + +#include + +bool +SocketAddress::operator==(SocketAddress other) const +{ + return size == other.size && memcmp(address, other.address, size) == 0; +} diff --git a/src/net/SocketAddress.hxx b/src/net/SocketAddress.hxx new file mode 100644 index 000000000..0577edd72 --- /dev/null +++ b/src/net/SocketAddress.hxx @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2012-2015 Max Kellermann + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * - Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * - Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE + * FOUNDATION OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, + * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED + * OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef SOCKET_ADDRESS_HXX +#define SOCKET_ADDRESS_HXX + +#include "Compiler.h" + +#include + +#ifdef WIN32 +#include +#else +#include +#endif + +/** + * An OO wrapper for struct sockaddr. + */ +class SocketAddress { +public: +#ifdef WIN32 + typedef int size_type; +#else + typedef socklen_t size_type; +#endif + +private: + const struct sockaddr *address; + size_type size; + +public: + SocketAddress() = default; + + constexpr SocketAddress(std::nullptr_t):address(nullptr), size(0) {} + + constexpr SocketAddress(const struct sockaddr *_address, + size_type _size) + :address(_address), size(_size) {} + + static constexpr SocketAddress Null() { + return nullptr; + } + + constexpr bool IsNull() const { + return address == nullptr; + } + + const struct sockaddr *GetAddress() const { + return address; + } + + constexpr size_type GetSize() const { + return size; + } + + constexpr int GetFamily() const { + return address->sa_family; + } + + /** + * Does the object have a well-defined address? Check !IsNull() + * before calling this method. + */ + bool IsDefined() const { + return GetFamily() != AF_UNSPEC; + } + + gcc_pure + bool operator==(const SocketAddress other) const; + + bool operator!=(const SocketAddress other) const { + return !(*this == other); + } +}; + +#endif diff --git a/src/net/SocketUtil.cxx b/src/net/SocketUtil.cxx index 556a8d037..b2ea63985 100644 --- a/src/net/SocketUtil.cxx +++ b/src/net/SocketUtil.cxx @@ -19,25 +19,19 @@ #include "config.h" #include "SocketUtil.hxx" +#include "SocketAddress.hxx" #include "SocketError.hxx" #include "system/fd_util.h" #include -#ifndef WIN32 -#include -#else -#include -#include -#endif - #ifdef HAVE_IPV6 #include #endif int socket_bind_listen(int domain, int type, int protocol, - const struct sockaddr *address, size_t address_length, + SocketAddress address, int backlog, Error &error) { @@ -60,7 +54,7 @@ socket_bind_listen(int domain, int type, int protocol, return -1; } - ret = bind(fd, address, address_length); + ret = bind(fd, address.GetAddress(), address.GetSize()); if (ret < 0) { SetSocketError(error); close_socket(fd); diff --git a/src/net/SocketUtil.hxx b/src/net/SocketUtil.hxx index ad02ef481..c0a0c95db 100644 --- a/src/net/SocketUtil.hxx +++ b/src/net/SocketUtil.hxx @@ -26,9 +26,7 @@ #ifndef MPD_SOCKET_UTIL_HXX #define MPD_SOCKET_UTIL_HXX -#include - -struct sockaddr; +class SocketAddress; class Error; /** @@ -47,7 +45,7 @@ class Error; */ int socket_bind_listen(int domain, int type, int protocol, - const struct sockaddr *address, size_t address_length, + SocketAddress address, int backlog, Error &error); diff --git a/src/output/plugins/httpd/HttpdInternal.hxx b/src/output/plugins/httpd/HttpdInternal.hxx index d3ea49cd4..c9f983e17 100644 --- a/src/output/plugins/httpd/HttpdInternal.hxx +++ b/src/output/plugins/httpd/HttpdInternal.hxx @@ -259,8 +259,7 @@ public: private: virtual void RunDeferred() override; - virtual void OnAccept(int fd, const sockaddr &address, - size_t address_length, int uid) override; + void OnAccept(int fd, SocketAddress address, int uid) override; }; extern const class Domain httpd_output_domain; diff --git a/src/output/plugins/httpd/HttpdOutputPlugin.cxx b/src/output/plugins/httpd/HttpdOutputPlugin.cxx index 05e3d53d0..89dbcb85f 100644 --- a/src/output/plugins/httpd/HttpdOutputPlugin.cxx +++ b/src/output/plugins/httpd/HttpdOutputPlugin.cxx @@ -26,6 +26,7 @@ #include "encoder/EncoderPlugin.hxx" #include "encoder/EncoderList.hxx" #include "net/Resolver.hxx" +#include "net/SocketAddress.hxx" #include "Page.hxx" #include "IcyMetaDataServer.hxx" #include "system/fd_util.h" @@ -201,16 +202,14 @@ HttpdOutput::RunDeferred() } void -HttpdOutput::OnAccept(int fd, const sockaddr &address, - size_t address_length, gcc_unused int uid) +HttpdOutput::OnAccept(int fd, SocketAddress address, gcc_unused int uid) { /* the listener socket has become readable - a client has connected */ #ifdef HAVE_LIBWRAP - if (address.sa_family != AF_UNIX) { - const auto hostaddr = sockaddr_to_string(&address, - address_length); + if (address.GetFamily() != AF_UNIX) { + const auto hostaddr = sockaddr_to_string(address); // TODO: shall we obtain the program name from argv[0]? const char *progname = "mpd"; @@ -230,7 +229,6 @@ HttpdOutput::OnAccept(int fd, const sockaddr &address, } #else (void)address; - (void)address_length; #endif /* HAVE_WRAP */ const ScopeLock protect(mutex); -- cgit v1.2.3