diff --git a/glorytun.c b/glorytun.c index bfd7824..11424f6 100644 --- a/glorytun.c +++ b/glorytun.c @@ -203,15 +203,12 @@ static void gt_set_signal (void) sigaction(SIGPIPE, &sa, NULL); } -static int read_to_buffer (int fd, buffer_t *buffer, size_t size) +static ssize_t fd_read (int fd, void *data, size_t size) { - if (!size || buffer_write_size(buffer)write, size); - - if (!ret) - return 0; + ssize_t ret = read(fd, data, size); if (ret==-1) { if (errno==EAGAIN || errno==EINTR) @@ -221,20 +218,15 @@ static int read_to_buffer (int fd, buffer_t *buffer, size_t size) return 0; } - buffer->write += ret; - return ret; } -static int write_from_buffer (int fd, buffer_t *buffer, size_t size) +static ssize_t fd_write (int fd, const void *data, size_t size) { - if (!size || buffer_read_size(buffer)read, size); - - if (!ret) - return 0; + ssize_t ret = write(fd, data, size); if (ret==-1) { if (errno==EAGAIN || errno==EINTR) @@ -244,8 +236,6 @@ static int write_from_buffer (int fd, buffer_t *buffer, size_t size) return 0; } - buffer->read += ret; - return ret; } @@ -297,6 +287,17 @@ static int option (int argc, char **argv, int n, struct option *opt) return 0; } +static ssize_t get_ip_size (const uint8_t *data, size_t size) +{ + if (size<20) + return -1; + + if ((data[0]>>4)==4) + return (data[2]<<8)|data[3]; + + return 0; +} + struct netio { int fd; buffer_t recv; @@ -391,59 +392,65 @@ int main (int argc, char **argv) buffer_shift(&tun.recv); if (fds[0].revents & POLLIN) { - if (buffer_write_size(&tun.recv)) { - uint8_t *tmp = tun.recv.write; - int r = read_to_buffer(fds[0].fd, &tun.recv, buffer_write_size(&tun.recv)); + size_t size = buffer_write_size(&tun.recv); + ssize_t r = fd_read(fds[0].fd, tun.recv.write, size); - if (!r) - return 2; + if (!r) + return 2; - if (r>0 && r!=((tmp[2]<<8)|tmp[3])) - tun.recv.write = tmp; - } + if (r>0 && r==get_ip_size(tun.recv.write, size)) + tun.recv.write += r; } if (fds[1].revents & POLLOUT) fds[1].events = POLLIN; if (buffer_read_size(&tun.recv)) { - int r = write_from_buffer(fds[1].fd, &tun.recv, buffer_read_size(&tun.recv)); + ssize_t r = fd_write(fds[1].fd, tun.recv.read, buffer_read_size(&tun.recv)); if (!r) goto restart; if (r==-1) fds[1].events = POLLIN|POLLOUT; + + if (r>0) + tun.recv.read += r; } buffer_shift(&sock.recv); if (fds[1].revents & POLLIN) { - int r = read_to_buffer(fds[1].fd, &sock.recv, buffer_write_size(&sock.recv)); + ssize_t r = fd_read(fds[1].fd, sock.recv.write, buffer_write_size(&sock.recv)); if (!r) goto restart; + + if (r>0) + sock.recv.write += r; } if (fds[0].revents & POLLOUT) fds[0].events = POLLIN; - if (buffer_read_size(&sock.recv)>=20) { - if ((sock.recv.read[0]>>4)!=4) - goto restart; + size_t size = buffer_read_size(&sock.recv); + ssize_t ip_size = get_ip_size(sock.recv.read, size); - size_t ps = (sock.recv.read[2]<<8)|sock.recv.read[3]; + if (ip_size>0 && (size_t)ip_size<=size) { + ssize_t r = fd_write(fds[0].fd, sock.recv.read, ip_size); - if (buffer_read_size(&sock.recv)>=ps) { - int r = write_from_buffer(fds[0].fd, &sock.recv, ps); + if (!r) + return 2; - if (!r) - return 2; + if (r==-1) + fds[0].events = POLLIN|POLLOUT; - if (r==-1) - fds[0].events = POLLIN|POLLOUT; - } + if (r>0) + sock.recv.read += r; } + + if (!ip_size) + goto restart; } restart: