diff --git a/mud.c b/mud.c index cf6ec1b..621d878 100644 --- a/mud.c +++ b/mud.c @@ -581,13 +581,22 @@ mud_get_key(struct mud *mud, unsigned char *key, size_t *size) int mud_set_key(struct mud *mud, unsigned char *key, size_t size) { - if (!key || (size < MUD_KEY_SIZE)) { + if (key && (size < MUD_KEY_SIZE)) { errno = EINVAL; return -1; } - memcpy(mud->crypto.private.encrypt.key, key, MUD_KEY_SIZE); - memcpy(mud->crypto.private.decrypt.key, key, MUD_KEY_SIZE); + unsigned char *enc = mud->crypto.private.encrypt.key; + unsigned char *dec = mud->crypto.private.decrypt.key; + + if (key) { + memcpy(enc, key, MUD_KEY_SIZE); + sodium_memzero(key, size); + } else { + randombytes_buf(enc, MUD_KEY_SIZE); + } + + memcpy(dec, enc, MUD_KEY_SIZE); mud->crypto.current = mud->crypto.private; mud->crypto.next = mud->crypto.private; @@ -596,15 +605,6 @@ mud_set_key(struct mud *mud, unsigned char *key, size_t size) return 0; } -int -mud_new_key(struct mud *mud) -{ - unsigned char key[MUD_KEY_SIZE]; - - randombytes_buf(key, sizeof(key)); - return mud_set_key(mud, key, sizeof(key)); -} - int mud_set_tc(struct mud *mud, int tc) { @@ -814,11 +814,13 @@ mud_create(int port, int v4, int v6) if (sodium_init() == -1) return NULL; - struct mud *mud = calloc(1, sizeof(struct mud)); + struct mud *mud = sodium_malloc(sizeof(struct mud)); if (!mud) return NULL; + memset(mud, 0, sizeof(struct mud)); + mud->fd = mud_create_socket(port, v4, v6); if (mud->fd == -1) { @@ -859,7 +861,7 @@ mud_delete(struct mud *mud) errno = err; } - free(mud); + sodium_free(mud); } static int diff --git a/mud.h b/mud.h index 1e8caae..41d94c6 100644 --- a/mud.h +++ b/mud.h @@ -9,7 +9,6 @@ void mud_delete (struct mud *); int mud_get_fd (struct mud *); -int mud_new_key (struct mud *); int mud_set_key (struct mud *, unsigned char *, size_t); int mud_get_key (struct mud *, unsigned char *, size_t *);