diff --git a/mud.c b/mud.c index fd88652..33c9d80 100644 --- a/mud.c +++ b/mud.c @@ -664,9 +664,31 @@ int mud_is_up (struct mud *mud) return ret; } +static +struct cmsghdr *mud_get_pktinfo (struct msghdr *msg, int family) +{ + int cmsg_level = IPPROTO_IP; + int cmsg_type = IP_PKTINFO; + + if (family == AF_INET6) { + cmsg_level = IPPROTO_IPV6; + cmsg_type = IPV6_PKTINFO; + } + + struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg); + + for (; cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) { + if ((cmsg->cmsg_level == cmsg_level) && + (cmsg->cmsg_type == cmsg_type)) + return cmsg; + } + + return NULL; +} + int mud_pull (struct mud *mud) { - unsigned char ctrl[1024]; + unsigned char ctrl[256]; for (int i = 0; i < 16; i++) { unsigned next = MUD_PACKET_NEXT(mud->rx.end); @@ -707,28 +729,14 @@ int mud_pull (struct mud *mud) mud_unmapv4((struct sockaddr *)&addr); - int cmsg_level = IPPROTO_IP; - int cmsg_type = IP_PKTINFO; - - if (addr.ss_family == AF_INET6) { - cmsg_level = IPPROTO_IPV6; - cmsg_type = IPV6_PKTINFO; - } - - struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); - - for (; cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) { - if ((cmsg->cmsg_level == cmsg_level) && - (cmsg->cmsg_type == cmsg_type)) - break; - } + struct cmsghdr *cmsg = mud_get_pktinfo(&msg, addr.ss_family); if (!cmsg) continue; unsigned index = 0; - if (cmsg_level == IPPROTO_IP) { + if (cmsg->cmsg_level == IPPROTO_IP) { memcpy(&index, &((struct in_pktinfo *)CMSG_DATA(cmsg))->ipi_ifindex, sizeof(index)); @@ -748,7 +756,7 @@ int mud_pull (struct mud *mud) struct cmsghdr *send_cmsg = CMSG_FIRSTHDR(&send_msg); - if (cmsg_level == IPPROTO_IP) { + if (cmsg->cmsg_level == IPPROTO_IP) { memcpy(&((struct in_pktinfo *)CMSG_DATA(send_cmsg))->ipi_spec_dst, &((struct in_pktinfo *)CMSG_DATA(cmsg))->ipi_addr, sizeof(struct in_addr));