diff --git a/mud.c b/mud.c index 5e5ffe0..ab5746f 100644 --- a/mud.c +++ b/mud.c @@ -151,6 +151,7 @@ struct mud { uint64_t send_timeout; uint64_t time_tolerance; uint64_t keyx_timeout; + struct sockaddr_storage addr; struct mud_path *paths; unsigned count; struct { @@ -443,6 +444,30 @@ mud_get_paths(struct mud *mud, unsigned *ret_count) return paths; } +static void +mud_copy_port(struct sockaddr_storage *d, struct sockaddr_storage *s) +{ + uint16_t port = 0; + + switch (s->ss_family) { + case AF_INET: + port = ((struct sockaddr_in *)s)->sin_port; + break; + case AF_INET6: + port = ((struct sockaddr_in6 *)s)->sin6_port; + break; + } + + switch (d->ss_family) { + case AF_INET: + ((struct sockaddr_in *)d)->sin_port = port; + break; + case AF_INET6: + ((struct sockaddr_in6 *)d)->sin6_port = port; + break; + } +} + static struct mud_path * mud_get_path(struct mud *mud, struct sockaddr_storage *local_addr, struct sockaddr_storage *addr, int create) @@ -452,6 +477,8 @@ mud_get_path(struct mud *mud, struct sockaddr_storage *local_addr, return NULL; } + mud_copy_port(local_addr, &mud->addr); + for (unsigned i = 0; i < mud->count; i++) { struct mud_path *path = &mud->paths[i]; @@ -477,7 +504,7 @@ mud_get_path(struct mud *mud, struct sockaddr_storage *local_addr, if (!path) { struct mud_path *paths = realloc(mud->paths, - (mud->count + 1) * sizeof(struct mud_path)); + (mud->count + 1) * sizeof(struct mud_path)); if (!paths) return NULL; @@ -650,7 +677,7 @@ mud_set_keyx_timeout(struct mud *mud, unsigned long msec) } int -mud_set_state(struct mud *mud, struct sockaddr *peer, enum mud_state state) +mud_set_state(struct mud *mud, struct sockaddr *addr, enum mud_state state) { if (!mud->peer.set || (state < MUD_DOWN) || (state > MUD_UP)) { @@ -658,12 +685,13 @@ mud_set_state(struct mud *mud, struct sockaddr *peer, enum mud_state state) return -1; } - struct sockaddr_storage addr; + struct sockaddr_storage local_addr; - if (mud_ss_from_sa(&addr, peer)) + if (mud_ss_from_sa(&local_addr, addr)) return -1; - struct mud_path *path = mud_get_path(mud, &addr, &mud->peer.addr, state > MUD_DOWN); + struct mud_path *path = mud_get_path(mud, + &local_addr, &mud->peer.addr, state > MUD_DOWN); if (!path) return -1; @@ -862,6 +890,8 @@ mud_create(struct sockaddr *addr) mud->tc = MUD_PACKET_TC; mud->mtu = sizeof(struct mud_packet); + memcpy(&mud->addr, addr, addrlen); + mud_keyx_init(mud); randombytes_buf(mud->local.kiss, sizeof(mud->local.kiss)); @@ -966,21 +996,16 @@ mud_decrypt(struct mud *mud, } static int -mud_localaddr(struct sockaddr_storage *addr, struct msghdr *msg, int family) +mud_localaddr(struct sockaddr_storage *addr, struct msghdr *msg) { - int cmsg_level = IPPROTO_IP; - int cmsg_type = MUD_PKTINFO; - - if (family == AF_INET6) { - cmsg_level = IPPROTO_IPV6; - cmsg_type = IPV6_PKTINFO; - } - struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg); for (; cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) { - if ((cmsg->cmsg_level == cmsg_level) && - (cmsg->cmsg_type == cmsg_type)) + if ((cmsg->cmsg_level == IPPROTO_IP) && + (cmsg->cmsg_type == MUD_PKTINFO)) + break; + if ((cmsg->cmsg_level == IPPROTO_IPV6) && + (cmsg->cmsg_type == IPV6_PKTINFO)) break; } @@ -988,18 +1013,21 @@ mud_localaddr(struct sockaddr_storage *addr, struct msghdr *msg, int family) return 1; memset(addr, 0, sizeof(struct sockaddr_storage)); - addr->ss_family = family; - if (family == AF_INET) { + if (cmsg->cmsg_level == IPPROTO_IP) { + addr->ss_family = AF_INET; memcpy(&((struct sockaddr_in *)addr)->sin_addr, MUD_PKTINFO_SRC(CMSG_DATA(cmsg)), sizeof(struct in_addr)); } else { + addr->ss_family = AF_INET6; memcpy(&((struct sockaddr_in6 *)addr)->sin6_addr, &((struct in6_pktinfo *)CMSG_DATA(cmsg))->ipi6_addr, sizeof(struct in6_addr)); } + mud_unmapv4(addr); + return 0; } @@ -1237,7 +1265,7 @@ mud_recv(struct mud *mud, void *data, size_t size) struct sockaddr_storage local_addr; - if (mud_localaddr(&local_addr, &msg, addr.ss_family)) + if (mud_localaddr(&local_addr, &msg)) return 0; struct mud_path *path = mud_get_path(mud, &local_addr, &addr, 1);