diff --git a/lib/socket.c b/lib/socket.c index d3e636e..1ca7783 100644 --- a/lib/socket.c +++ b/lib/socket.c @@ -120,7 +120,7 @@ static struct nl_sock *__alloc_socket(struct nl_cb *cb) return NULL; sk->s_fd = -1; - sk->s_cb = cb; + sk->s_cb = nl_cb_get(cb); sk->s_local.nl_family = AF_NETLINK; sk->s_peer.nl_family = AF_NETLINK; sk->s_seq_expect = sk->s_seq_next = time(0); @@ -141,12 +141,18 @@ static struct nl_sock *__alloc_socket(struct nl_cb *cb) struct nl_sock *nl_socket_alloc(void) { struct nl_cb *cb; - + struct nl_sock *sk; + cb = nl_cb_alloc(default_cb); if (!cb) return NULL; - return __alloc_socket(cb); + /* will increment cb reference count on success */ + sk = __alloc_socket(cb); + + nl_cb_put(cb); + + return sk; } /** @@ -163,7 +169,7 @@ struct nl_sock *nl_socket_alloc_cb(struct nl_cb *cb) if (cb == NULL) BUG(); - return __alloc_socket(nl_cb_get(cb)); + return __alloc_socket(cb); } /** @@ -519,6 +525,9 @@ struct nl_cb *nl_socket_get_cb(const struct nl_sock *sk) void nl_socket_set_cb(struct nl_sock *sk, struct nl_cb *cb) { + if (cb == NULL) + BUG(); + nl_cb_put(sk->s_cb); sk->s_cb = nl_cb_get(cb); }