#include "sender_key.h" #include #include #include "hkdf.h" #include "signal_protocol_internal.h" #define HASH_OUTPUT_SIZE 32 struct sender_message_key { signal_type_base base; uint32_t iteration; signal_buffer *iv; signal_buffer *cipher_key; signal_buffer *seed; signal_context *global_context; }; struct sender_chain_key { signal_type_base base; uint32_t iteration; signal_buffer *chain_key; signal_context *global_context; }; static int sender_chain_key_get_derivative(signal_buffer **derivative, uint8_t seed, signal_buffer *key, signal_context *global_context); int sender_message_key_create(sender_message_key **key, uint32_t iteration, signal_buffer *seed, signal_context *global_context) { sender_message_key *result = 0; int ret = 0; ssize_t ret_size = 0; hkdf_context *kdf = 0; static const char info_material[] = "WhisperGroup"; uint8_t salt[HASH_OUTPUT_SIZE]; uint8_t *derivative = 0; assert(global_context); if(!seed) { return SG_ERR_INVAL; } memset(salt, 0, sizeof(salt)); result = malloc(sizeof(sender_message_key)); if(!result) { return SG_ERR_NOMEM; } SIGNAL_INIT(result, sender_message_key_destroy); ret = hkdf_create(&kdf, 3, global_context); if(ret < 0) { goto complete; } ret_size = hkdf_derive_secrets(kdf, &derivative, signal_buffer_data(seed), signal_buffer_len(seed), salt, sizeof(salt), (uint8_t *)info_material, sizeof(info_material) - 1, 48); if(ret_size != 48) { ret = (ret_size < 0) ? (int)ret_size : SG_ERR_UNKNOWN; signal_log(global_context, SG_LOG_WARNING, "hkdf_derive_secrets failed"); goto complete; } result->iteration = iteration; result->seed = signal_buffer_copy(seed); if(!result->seed) { ret = SG_ERR_NOMEM; goto complete; } result->iv = signal_buffer_create(derivative, 16); if(!result->iv) { ret = SG_ERR_NOMEM; goto complete; } result->cipher_key = signal_buffer_create(derivative + 16, 32); if(!result->cipher_key) { ret = SG_ERR_NOMEM; goto complete; } result->global_context = global_context; complete: SIGNAL_UNREF(kdf); if(derivative) { free(derivative); } if(ret < 0) { SIGNAL_UNREF(result); } else { ret = 0; *key = result; } return ret; } uint32_t sender_message_key_get_iteration(sender_message_key *key) { assert(key); return key->iteration; } signal_buffer *sender_message_key_get_iv(sender_message_key *key) { assert(key); return key->iv; } signal_buffer *sender_message_key_get_cipher_key(sender_message_key *key) { assert(key); return key->cipher_key; } signal_buffer *sender_message_key_get_seed(sender_message_key *key) { assert(key); return key->seed; } void sender_message_key_destroy(signal_type_base *type) { sender_message_key *key = (sender_message_key *)type; signal_buffer_bzero_free(key->iv); signal_buffer_bzero_free(key->cipher_key); signal_buffer_bzero_free(key->seed); free(key); } int sender_chain_key_create(sender_chain_key **key, uint32_t iteration, signal_buffer *chain_key, signal_context *global_context) { sender_chain_key *result = 0; int ret = 0; assert(global_context); if(!chain_key) { return SG_ERR_INVAL; } result = malloc(sizeof(sender_chain_key)); if(!result) { return SG_ERR_NOMEM; } SIGNAL_INIT(result, sender_chain_key_destroy); result->iteration = iteration; result->chain_key = signal_buffer_copy(chain_key); if(!result->chain_key) { ret = SG_ERR_NOMEM; goto complete; } result->global_context = global_context; complete: if(ret < 0) { SIGNAL_UNREF(result); } else { ret = 0; *key = result; } return ret; } uint32_t sender_chain_key_get_iteration(sender_chain_key *key) { assert(key); return key->iteration; } int sender_chain_key_create_message_key(sender_chain_key *key, sender_message_key **message_key) { static const uint8_t MESSAGE_KEY_SEED = 0x01; int ret = 0; signal_buffer *derivative = 0; sender_message_key *result = 0; assert(key); ret = sender_chain_key_get_derivative(&derivative, MESSAGE_KEY_SEED, key->chain_key, key->global_context); if(ret < 0) { goto complete; } ret = sender_message_key_create(&result, key->iteration, derivative, key->global_context); complete: signal_buffer_free(derivative); if(ret >= 0) { ret = 0; *message_key = result; } return ret; } int sender_chain_key_create_next(sender_chain_key *key, sender_chain_key **next_key) { static const uint8_t CHAIN_KEY_SEED = 0x02; int ret = 0; signal_buffer *derivative = 0; sender_chain_key *result = 0; assert(key); ret = sender_chain_key_get_derivative(&derivative, CHAIN_KEY_SEED, key->chain_key, key->global_context); if(ret < 0) { goto complete; } ret = sender_chain_key_create(&result, key->iteration + 1, derivative, key->global_context); complete: signal_buffer_free(derivative); if(ret >= 0) { ret = 0; *next_key = result; } return ret; } signal_buffer *sender_chain_key_get_seed(sender_chain_key *key) { assert(key); return key->chain_key; } void sender_chain_key_destroy(signal_type_base *type) { sender_chain_key *key = (sender_chain_key *)type; signal_buffer_bzero_free(key->chain_key); free(key); } int sender_chain_key_get_derivative(signal_buffer **derivative, uint8_t seed, signal_buffer *key, signal_context *global_context) { int result = 0; signal_buffer *output_buffer = 0; void *hmac_context = 0; result = signal_hmac_sha256_init(global_context, &hmac_context, signal_buffer_data(key), signal_buffer_len(key)); if(result < 0) { goto complete; } result = signal_hmac_sha256_update(global_context, hmac_context, &seed, sizeof(seed)); if(result < 0) { goto complete; } result = signal_hmac_sha256_final(global_context, hmac_context, &output_buffer); if(result < 0) { goto complete; } complete: signal_hmac_sha256_cleanup(global_context, hmac_context); if(result < 0) { signal_buffer_free(output_buffer); } else { *derivative = output_buffer; } return result; }