# Convert Whisper transformer model from PyTorch to ggml format
#
# Usage: python convert-pt-to-ggml.py ~/.cache/whisper/medium.pt ~/path/to/repo/whisper/ ./models/whisper-medium
#
# You need to clone the original repo in ~/path/to/repo/whisper/
#
#  git clone https://github.com/openai/whisper ~/path/to/repo/whisper/
#
# It is used to load various assets needed by the algorithm:
#
#  - tokenizer
#  - mel filters
#
# Also, you need to have the original models in ~/.cache/whisper/
# See the original repo for more details.
#
# This script loads the specified model and whisper assets and saves them in ggml format.
# The output is a single binary file containing the following information:
#
#  - hparams
#  - mel filters
#  - tokenizer vocab
#  - model variables
#
# For each variable, write the following:
#
#  - Number of dimensions (int)
#  - Name length (int)
#  - Data type (int): 0 -> float32, 1 -> float16
#  - Dimensions (int[n_dims]), written in reverse order
#  - Name (char[name_length])
#  - Data (float32 or float16, as given by the data type field)
#

import io
import os
import sys
import struct
import json

import code
import torch
import numpy as np

from transformers import GPTJForCausalLM
from transformers import GPT2TokenizerFast

# ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L10-L110
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "iw": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}

# Hyperparameters copied from the checkpoint into the file header, in the
# exact order in which they are written (each as a native int).
HPARAM_KEYS = (
    "n_vocab",
    "n_audio_ctx",
    "n_audio_state",
    "n_audio_head",
    "n_audio_layer",
    "n_text_ctx",
    "n_text_state",
    "n_text_head",
    "n_text_layer",
    "n_mels",
)

# Tensors that are reshaped from [n] to [n, 1] before being written.
CONV_BIAS_TENSORS = ("encoder.conv1.bias", "encoder.conv2.bias")

# Tensors that are always stored as float32, even when writing an f16 model.
# The conv biases are listed explicitly because they are 2-D after the reshape
# above and would otherwise not match the "n_dims < 2" rule.
F32_ONLY_TENSORS = CONV_BIAS_TENSORS + (
    "encoder.positional_embedding",
    "decoder.positional_embedding",
)


# ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L273-L292
def build_tokenizer(path_to_whisper_repo: str, name: str = "gpt2"):
    """Load a tokenizer from the whisper repo assets and add the special tokens.

    Args:
        path_to_whisper_repo: Path to a local clone of the whisper repository.
        name: Name of the tokenizer directory under "whisper/assets".

    Returns:
        The GPT2TokenizerFast instance loaded from
        "<path_to_whisper_repo>/whisper/assets/<name>" with the additional
        special tokens registered.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    path = os.path.join(path_to_whisper_repo, "whisper/assets", name)
    tokenizer = GPT2TokenizerFast.from_pretrained(path)

    # One language token per entry of LANGUAGES, in dictionary order.
    specials = [
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nocaptions|>",
        "<|notimestamps|>",
    ]

    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
    return tokenizer


# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    # Every byte value not covered above is mapped to a code point >= 256.
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


if len(sys.argv) < 4:
    print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n")
    sys.exit(1)

fname_inp = sys.argv[1]
dir_whisper = sys.argv[2]
dir_out = sys.argv[3]

# try to load PyTorch binary data
try:
    with open(fname_inp, "rb") as f_inp:
        model_bytes = f_inp.read()
    with io.BytesIO(model_bytes) as fp:
        checkpoint = torch.load(fp, map_location="cpu")
except Exception as err:
    # Catch Exception (not a bare except) so Ctrl-C and sys.exit still propagate,
    # and report the underlying cause instead of hiding it.
    print("Error: failed to load PyTorch model file: %s (%s)" % (fname_inp, err))
    sys.exit(1)

hparams = checkpoint["dims"]
print("hparams:", hparams)

list_vars = checkpoint["model_state_dict"]

# load mel filters
n_mels = hparams["n_mels"]
with np.load(os.path.join(dir_whisper, "whisper/assets", "mel_filters.npz")) as f:
    filters = torch.from_numpy(f[f"mel_{n_mels}"])

multilingual = hparams["n_vocab"] == 51865
tokenizer = build_tokenizer(dir_whisper, "multilingual" if multilingual else "gpt2")

dir_tokenizer = tokenizer.name_or_path

# output in the same directory as the model
fname_out = dir_out + "/ggml-model.bin"

# vocab.json holds non-ASCII token strings, so do not rely on the locale encoding
with open(dir_tokenizer + "/vocab.json", "r", encoding="utf-8") as f:
    tokens = json.load(f)

# use 16-bit or 32-bit floats (any 4th command-line argument selects 32-bit)
use_f16 = True
if len(sys.argv) > 4:
    use_f16 = False
    fname_out = dir_out + "/ggml-model-f32.bin"

byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

# The context manager guarantees the output file is closed even if the
# conversion fails part-way through.
with open(fname_out, "wb") as fout:
    # header: magic, hyperparameters, f16 flag
    fout.write(struct.pack("i", 0x67676d6c))  # magic: ggml in hex
    for hparam_key in HPARAM_KEYS:
        fout.write(struct.pack("i", hparams[hparam_key]))
    fout.write(struct.pack("i", use_f16))

    # write mel filters: both dimensions, then all values as float32 in row-major order
    fout.write(struct.pack("i", filters.shape[0]))
    fout.write(struct.pack("i", filters.shape[1]))
    fout.write(filters.numpy().astype(np.float32).tobytes())

    # write tokenizer vocab: token count, then each token as (length, raw bytes)
    fout.write(struct.pack("i", len(tokens)))
    for token in tokens:
        # map the printable stand-in characters back to the original byte values
        text = bytearray([byte_decoder[c] for c in token])
        fout.write(struct.pack("i", len(text)))
        fout.write(text)

    # write model variables
    for name, tensor in list_vars.items():
        data = tensor.squeeze().numpy()
        print("Processing variable: " + name + " with shape: ", data.shape)

        # reshape conv bias from [n] to [n, 1]
        if name in CONV_BIAS_TENSORS:
            data = data.reshape(data.shape[0], 1)
            print(" Reshaped variable: " + name + " to shape: ", data.shape)

        n_dims = len(data.shape)

        # looks like the whisper models are in f16 by default
        # so we need to convert the small tensors to f32 until we fully support f16 in ggml
        # ftype == 0 -> float32, ftype == 1 -> float16
        #
        # The data is always converted explicitly so that the ftype written in
        # the header matches the dtype of the bytes that follow it, whatever
        # dtype the checkpoint tensor had.
        if use_f16:
            if n_dims < 2 or name in F32_ONLY_TENSORS:
                print(" Converting to float32")
                data = data.astype(np.float32)
                ftype = 0
            else:
                data = data.astype(np.float16)
                ftype = 1
        else:
            data = data.astype(np.float32)
            ftype = 0

        # header
        name_bytes = name.encode('utf-8')
        fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype))
        # dimensions are written last-to-first
        for dim in reversed(data.shape):
            fout.write(struct.pack("i", dim))
        fout.write(name_bytes)

        # data
        data.tofile(fout)

print("Done. Output file: " + fname_out)
print("")