{ "cells": [ { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from transformers import GPT2Tokenizer\n", "from rusty_dawg import Dawg, PyDawg\n", "\n", "\n", "# Expects to find a DAWG file in the `dawg` directory below the project root.\n", "dawg_path = \"../dawg/wikitext-2-raw.dawg\"\n", "dawg = Dawg.load(dawg_path)\n", "\n", "# Make sure the tokenizer matches the one used to construct the DAWG.\n", "tokenizer = GPT2Tokenizer.from_pretrained('gpt2')" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "py_dawg = PyDawg(dawg, tokenizer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's query with a string from Wikitext. We can use `get_suffix_context` to get the suffix contexts as output by the DAWG." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'tokens': [1722, 351, 2180, 569, 18354, 8704, 17740, 1830, 837, 569, 18354, 7496, 17740, 6711], 'suffix_contexts': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], 'context_counts': [234, 9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n" ] } ], "source": [ "# Substring found in the Wikitext 2 train data.\n", "query = \"As with previous Valkyira Chronicles games , Valkyria Chronicles III\"\n", "\n", "# Get suffix contexts\n", "suffix_contexts = py_dawg.get_suffix_context(query)\n", "\n", "print(suffix_contexts)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The method `get_matching_substrings` is a more user-friendly wrapper around `get_suffix_context`. It returns the spans of the input query that were found in the training corpus, each with a `count` giving the number of times that span appears in the corpus.\n", "\n", "Since this example comes from the training data, the last match covers the entire query. 
" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'tokens': (1722,), 'token_indices': (0,), 'text': 'As', 'count': 234}\n", "{'tokens': (1722, 351), 'token_indices': (0, 1), 'text': 'As with', 'count': 9}\n", "{'tokens': (1722, 351, 2180, 569, 18354, 8704, 17740, 1830, 837, 569, 18354, 7496, 17740, 6711), 'token_indices': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13), 'text': 'As with previous Valkyira Chronicles games, Valkyria Chronicles III', 'count': 1}\n" ] } ], "source": [ "# Return a list of all substrings in the DAWG that match the query.\n", "matching_substrings = py_dawg.get_matching_substrings(query, remove_redundant=True)\n", "\n", "for entry in matching_substrings[\"matches\"]:\n", " print(entry)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's a query that isn't included in the training data. No single match covers the whole query; only shorter spans of it are found in the corpus." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'tokens': (5842,), 'token_indices': (0,), 'text': 'Us', 'count': 4}\n", "{'tokens': (391,), 'token_indices': (1,), 'text': 'ain', 'count': 271}\n", "{'tokens': (18100,), 'token_indices': (2,), 'text': ' bolt', 'count': 4}\n", "{'tokens': (900,), 'token_indices': (3,), 'text': ' set', 'count': 725}\n", "{'tokens': (900, 262), 'token_indices': (3, 4), 'text': ' set the', 'count': 29}\n", "{'tokens': (262, 995), 'token_indices': (4, 5), 'text': ' the world', 'count': 347}\n", "{'tokens': (262, 995, 1700), 'token_indices': (4, 5, 6), 'text': ' the world record', 'count': 2}\n", "{'tokens': (995, 1700, 287, 262, 1802), 'token_indices': (5, 6, 7, 8, 9), 'text': ' world record in the 100', 'count': 1}\n", "{'tokens': (12,), 'token_indices': (10,), 'text': '-', 'count': 17027}\n", "{'tokens': (14470,), 'token_indices': (12,), 'text': ' dash', 'count': 7}\n", "{'tokens': (13,), 'token_indices': (13,), 'text': '.', 'count': 
8668}\n" ] } ], "source": [ "# This sentence is not taken from Wikitext, so only partial matches are expected.\n", "query = \"Usain bolt set the world record in the 100-meter dash.\"\n", "matching_substrings = py_dawg.get_matching_substrings(query)\n", "\n", "for entry in matching_substrings[\"matches\"]:\n", "    print(entry)" ] } ], "metadata": { "kernelspec": { "display_name": "rusty-dawg", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.17" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }