#!/usr/bin/env python3
# Copyright (c) 2010 ArtForz -- public domain half-a-node
# Copyright (c) 2012 Jeff Garzik
# Copyright (c) 2010-2016 The Bitcoin Core developers
# Copyright (c) 2024 Bitcoin Unlimited
# Distributed under the MIT software license, see the accompanying
# file COPYING or http://www.opensource.org/licenses/mit-license.php.
#
# mininode.py - Bitcoin P2P network half-a-node
#
# This python code was modified from ArtForz' public domain half-a-node, as
# found in the mini-node branch of http://github.com/jgarzik/pynode.
#
# NodeConn: an object which manages p2p connectivity to a bitcoin node
# NodeConnCB: a base class that describes the interface for receiving
#     callbacks with network messages from a NodeConn
# CBlock, CTransaction, CBlockHeader, CTxIn, CTxOut, etc....:
#     data structures that should map to corresponding structures in
#     bitcoin/primitives (imported from the network-specific nodemessages)
# MsgBlock, MsgTx, MsgHeaders, etc.:
#     data structures that represent network messages
# ser_*, deser_*: functions that handle serialization/deserialization

import asyncio
import pdb
import socket
import struct
import time
import sys
from io import BytesIO
from threading import RLock
import logging
import traceback

from .constants import TEST_PEER_VERSION
from .environment import network, Network
from .util import wait_for, sha256

# These structs are the same for BCH and NEX (and NYI in bch module)
from .nex.nodemessages import (
    MsgVerack,
    MsgGetdata,
    MsgPong,
    MsgPing,
    MsgVersion,
    MsgAddr,
    MsgInv,
    MsgGetblocks,
    MsgGetaddr,
    MsgGetheaders,
    MsgReject,
    MsgSendheaders,
    MsgMempool,
    CInv,
)

if network() == Network.BCH:
    from .bch.nodemessages import (
        CTransaction,
        MsgBlock,
        MsgHeaders,
        MsgTx,
        CBlockHeader,
    )
elif network() == Network.NEX:
    from .nex.nodemessages import (
        CTransaction,
        MsgBlock,
        MsgHeaders,
        MsgTx,
        CBlockHeader,
    )
else:
    raise NotImplementedError()

# Global lock guarding data shared between message callbacks and the test
# logic. NodeConnCB.deliver() holds it while running an on_<command> handler,
# and wait_until() / NodeConnCB.wait_for() hold it while evaluating their
# predicate. Test code should acquire it before touching any state that a
# NodeConnCB handler may also modify.
mininode_lock = RLock()

# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references to tasks, so a task that nobody references can be garbage
# collected before it has finished running.
_background_tasks = set()


def _spawn(coro):
    """Schedule *coro* as a task and keep it referenced until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


async def wait_until(predicate, attempts=float("inf"), timeout=float("inf")):
    """Poll *predicate* (under mininode_lock) every 0.25 seconds.

    :param predicate: zero-argument callable; polling stops when it is truthy.
    :param attempts: maximum number of polls.
    :param timeout: maximum wall-clock time to keep polling, in seconds.
    :returns: True as soon as the predicate holds, False once either the
        attempt budget or the timeout is exhausted.
    """
    attempt = 0
    # Measure real elapsed time so slow predicates still honour the timeout.
    deadline = time.monotonic() + timeout
    while attempt < attempts and time.monotonic() < deadline:
        with mininode_lock:
            if predicate():
                return True
        attempt += 1
        await asyncio.sleep(0.25)
    return False


MAX_INV_SZ = 50000
MAX_BLOCK_SIZE = 1000000

NEXA_MAGIC_BYTES = {
    "nexa": b"\x72\x27\x12\x21",
    "testnet3": b"\x72\x27\x12\x22",
    "regtest": b"\xea\xe5\xef\xea",
}

BCH_MAGIC_BYTES = {
    "mainnet": b"\xe3\xe1\xf3\xe8",
    "testnet3": b"\xf4\xe5\xf3\xf4",
    "regtest": b"\xda\xb5\xbf\xfa",
}


class MiniNodeError(Exception):
    pass


class DisconnectedError(MiniNodeError):
    pass


# pylint: disable=too-many-public-methods
class NodeConnCB:
    """Receiver of messages read by a NodeConn.

    NodeConn hands every decoded message to deliver(), which dispatches it to
    the ``on_<command>`` method of this object. Subclasses override the
    handlers they care about; the defaults here answer version/ping/inv and
    otherwise ignore the message.
    """

    def __init__(self, extversion=None):
        """Pass None to not use extversion. Pass a MsgExtversion object to use that.
        Pass True to use extversion, but your derived class will issue the message"""
        self.verack_received = False
        self.xverack_received = False
        self.xver = {}
        # deliver_sleep_time is helpful for debugging race conditions in p2p
        # tests; it causes message delivery to sleep for the specified time
        # before acquiring the global lock and delivering the next message.
        self.deliver_sleep_time = None
        self.disconnected = False
        self.extversion = extversion

    def set_deliver_sleep_time(self, value):
        with mininode_lock:
            self.deliver_sleep_time = value

    def get_deliver_sleep_time(self):
        with mininode_lock:
            return self.deliver_sleep_time

    async def wait_for(self, test_function):
        """Poll *test_function* (under mininode_lock) every 0.05 s, 200 times.

        May be called from the test logic. Raises DisconnectedError if the
        connection closes first, or TimeoutError if the condition never holds.
        """
        for _ in range(200):
            if self.disconnected:
                raise DisconnectedError()
            with mininode_lock:
                if test_function():
                    return
            await asyncio.sleep(0.05)
        raise TimeoutError(f"Waiting for {repr(test_function)} timed out.")

    async def wait_for_verack(self):
        """Wait until a verack arrives; tests can use this as a start signal."""
        await self.wait_for(lambda: self.verack_received)

    async def deliver(self, conn, message):
        """Dispatch *message* to this object's ``on_<command>`` handler.

        Exceptions raised by a handler are printed with a traceback rather than
        propagated, so a single failing handler does not kill the connection.
        """
        deliver_sleep = self.get_deliver_sleep_time()
        if deliver_sleep is not None:
            # Must be awaited: calling asyncio.sleep() without await only
            # creates a coroutine object and never actually sleeps.
            await asyncio.sleep(deliver_sleep)
        with mininode_lock:
            fn = "on_" + message.command.decode("ascii")
            try:
                getattr(self, fn)(conn, message)
            except Exception:
                print(f"ERROR delivering {repr(message)} ({sys.exc_info()[0]}) to {fn}")
                traceback.print_exc()

    def on_version(self, conn, message):
        _spawn(conn.send_message(MsgVerack()))
        conn.ver_send = min(TEST_PEER_VERSION, message.n_version)
        conn.ver_recv = conn.ver_send

    def on_verack(self, conn, _message):
        conn.ver_recv = conn.ver_send
        self.verack_received = True

    def on_inv(self, conn, message):
        # Request everything that was announced except type-0 (error) entries.
        want = MsgGetdata()
        want.inv.extend(i for i in message.inv if i.type != 0)
        if want.inv:
            _spawn(conn.send_message(want))

    def on_addr(self, conn, message):
        pass

    def on_alert(self, conn, message):
        pass

    def on_getdata(self, conn, message):
        pass

    def on_getblocks(self, conn, message):
        pass

    def on_tx(self, conn, message):
        pass

    def on_block(self, conn, message):
        pass

    def on_getaddr(self, conn, message):
        pass

    def on_headers(self, conn, message):
        pass

    def on_getheaders(self, conn, message):
        pass

    def on_ping(self, conn, message):
        _spawn(conn.send_message(MsgPong(message.nonce)))

    def on_reject(self, conn, message):
        pass

    def on_close(self, _conn):
        self.disconnected = True

    def on_mempool(self, conn, message=None):
        # deliver() always passes the message, so it must be accepted here.
        pass

    def on_pong(self, conn, message):
        pass

    def on_sendheaders(self, conn, message):
        pass

    def on_sendcmpct(self, conn, message):
        pass

    def on_cmpctblock(self, conn, message):
        pass

    def on_getblocktxn(self, conn, message):
        pass

    def on_blocktxn(self, conn, message):
        pass

    def on_xverack_old(self, _conn, _message):
        self.xverack_received = True

    def on_extversion(self, conn, message):
        # reply with a verack since we got both the version and extversion
        # messages
        conn.xver = message
        if self.extversion is not None:  # already sent otherwise
            _spawn(conn.send_message(MsgVerack()))


class SingleNodeConnCB(NodeConnCB):
    """NodeConnCB bound to exactly one NodeConn, with send/ping helpers."""

    def __init__(self):
        super().__init__()
        self.connection = None
        self.ping_counter = 1
        self.last_pong = MsgPong()

    def add_connection(self, conn):
        self.connection = conn

    async def send_message(self, message, pushbuf=False):
        """Forward to the attached NodeConn's send_message()."""
        assert self.connection is not None, "forgot to .add_connection"
        await self.connection.send_message(message, pushbuf)

    async def send_and_ping(self, message):
        await self.send_message(message)
        await self.sync_with_ping()

    def on_pong(self, conn, message):
        self.last_pong = message

    async def sync_with_ping(self, timeout=30):
        """Send a ping and wait up to *timeout* seconds for the matching pong.

        :returns: True if the pong arrived in time, False otherwise.
        """

        def received_pong():
            return self.last_pong.nonce == self.ping_counter

        await self.send_message(MsgPing(nonce=self.ping_counter))
        # timeout must be passed by keyword: the second positional parameter
        # of wait_until is the attempt count, not the timeout.
        success = await wait_until(received_pong, timeout=timeout)
        self.ping_counter += 1
        return success


def dupdate(x, y):
    """Update dict *x* with *y* in place and return *x*."""
    x.update(y)
    return x


# pylint: disable=too-few-public-methods
class MsgAnnotater:
    """Stamp messages with a running index and their stream offset."""

    def __init__(self):
        self.idx = 0

    def annotate(self, msg, conn):
        msg.idx = self.idx
        # NodeConn tracks its receive offset as ``cur_index``; fall back to the
        # legacy ``curIndex`` spelling for other connection objects.
        msg.offset = conn.cur_index if hasattr(conn, "cur_index") else conn.curIndex
        self.idx += 1
        return msg


async def create_connection_with_reuseaddr(dstaddr, dstport):
    """Open a TCP stream to (dstaddr, dstport) from an SO_REUSEADDR socket.

    SO_REUSEADDR mitigates "[Errno 99] Cannot assign requested address" on
    the aarch64 qa test setup.

    :returns: an ``(asyncio.StreamReader, asyncio.StreamWriter)`` pair.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setblocking(False)
        await asyncio.get_running_loop().sock_connect(sock, (dstaddr, dstport))
        # Wrap the connected socket in a StreamReader / StreamWriter pair.
        return await asyncio.open_connection(sock=sock)
    except BaseException:
        # Don't leak the file descriptor if connecting fails or is cancelled.
        sock.close()
        raise


# The actual NodeConn class
# This class provides an interface for a p2p connection to a specified node
# pylint: disable=too-many-instance-attributes
class NodeConn:
    """One p2p connection to a node: frames, sends and receives messages."""

    messagemap = {
        b"version": MsgVersion,
        b"verack": MsgVerack,
        b"addr": MsgAddr,
        b"inv": MsgInv,
        b"getdata": MsgGetdata,
        b"getblocks": MsgGetblocks,
        b"tx": MsgTx,
        b"block": MsgBlock,
        b"getaddr": MsgGetaddr,
        b"ping": MsgPing,
        b"pong": MsgPong,
        b"headers": MsgHeaders,
        b"getheaders": MsgGetheaders,
        b"reject": MsgReject,
        b"mempool": MsgMempool,
        b"sendheaders": MsgSendheaders,
    }

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        dstaddr,
        dstport,
        rpc,
        callback,
        net="regtest",
        services=1,
        send_initial_version=True,
        extversion_service=False,
    ):
        """Wrap an already-open stream. Prefer the ``create`` classmethod.

        :param callback: the NodeConnCB that receives decoded messages.
        :param net: key into the magic-bytes table for the active network.
        :param send_initial_version: schedule a version message immediately.
        :param extversion_service: also advertise service bit 11.
        """
        self.log = logging.getLogger(f"NodeConn({dstaddr}:{dstport})")
        self.dstaddr = dstaddr
        self.dstport = dstport
        self.reader = reader
        self.writer = writer
        self.sendbuf = b""
        self.recvbuf = b""
        self.ver_send = 209
        self.ver_recv = 209
        self.last_sent = 0
        self.state = "connected"
        self.network = net
        if network() == Network.BCH:
            self.magic_bytes = BCH_MAGIC_BYTES[net]
        elif network() == Network.NEX:
            self.magic_bytes = NEXA_MAGIC_BYTES[net]
        else:
            raise NotImplementedError()
        self.cb = callback
        self.disconnect = False
        # Number of bytes consumed from the receive stream so far.
        self.cur_index = 0
        self.allow0_checksum = False
        self.produce0_checksum = False
        self.num0_checksums = 0
        self.rpc = rpc
        self.exceptions = []
        self.cp = None
        self.recv_buf_len = None

        if send_initial_version:
            # stuff version msg into sendbuf
            vt = MsgVersion()
            if extversion_service:
                services = services | (1 << 11)
            vt.n_services = services
            vt.addr_to.ip = self.dstaddr
            vt.addr_to.port = self.dstport
            vt.addr_from.ip = "0.0.0.0"
            vt.addr_from.port = 0
            _spawn(self.send_message(vt))

    @classmethod
    async def create(
        cls,
        dstaddr,
        dstport,
        rpc,
        callback,
        net="regtest",
        services=1,
        send_initial_version=True,
        extversion_service=False,
    ):
        """Connect to (dstaddr, dstport), start the read loop, return the conn."""
        reader, writer = await create_connection_with_reuseaddr(dstaddr, dstport)
        # Report the connection only once it actually exists.
        print(f"MiniNode: Connected to Bitcoin Node IP # {dstaddr}:{dstport}")
        instance = cls(
            reader,
            writer,
            dstaddr,
            dstport,
            rpc,
            callback,
            net,
            services,
            send_initial_version,
            extversion_service,
        )
        _spawn(instance.handle_connection())
        return instance

    async def handle_connection(self):
        """Read loop: feed received bytes to got_data() until EOF or error."""
        while not self.disconnect:
            try:
                data = await self.reader.read(1024)
                if not data:
                    break
                self.recvbuf += data
                await self.got_data()
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.log.error("Error during connection handling: %s", e)
                self.exceptions.append(e)
                break
        self.close_connection()

    def show_debug_msg(self, msg):
        self.log.debug(msg)

    def close_connection(self):
        """Close the stream, drop buffers and notify the callback object."""
        self.show_debug_msg(
            f"MiniNode: Closing Connection to {self.dstaddr}:{self.dstport}... "
        )
        if self.writer:
            self.writer.close()
            _spawn(self.writer.wait_closed())
        self.recvbuf = b""
        self.sendbuf = b""
        self.state = "closed"
        self.log.info("Connection closed")
        self.cb.on_close(self)

    # pylint: disable=too-many-branches,too-many-nested-blocks
    async def got_data(self):
        """Decode and dispatch every complete message currently in recvbuf.

        Frame layout: magic(4) | command(12, NUL padded) | length(4, LE)
        [| checksum(4), only when ver_recv >= 209] | payload(length).
        Raises ValueError on bad magic or a bad checksum; handle_connection()
        records the error and closes the connection.
        """
        # NOTE(review): the tail of this method and the head of send_message()
        # were lost from the reviewed copy of this file (the text between
        # `<i` and `>= 209` had been stripped). The framing/checksum code below
        # is reconstructed from the fields this class keeps (allow0_checksum,
        # num0_checksums, messagemap) and upstream mininode.py -- confirm
        # against version control.
        self.recv_buf_len = len(self.recvbuf)
        while True:
            now_len = len(self.recvbuf)
            # Advance the stream offset by the bytes the last message consumed.
            self.cur_index += self.recv_buf_len - now_len
            self.recv_buf_len = now_len
            if now_len < 4:
                return
            if self.recvbuf[:4] != self.magic_bytes:
                raise ValueError(f"got garbage {repr(self.recvbuf)}")
            if self.ver_recv < 209:
                # Legacy framing without a checksum field.
                if len(self.recvbuf) < 4 + 12 + 4:
                    return
                command = self.recvbuf[4 : 4 + 12].split(b"\x00", 1)[0]
                msglen = struct.unpack("<i", self.recvbuf[4 + 12 : 4 + 12 + 4])[0]
                if len(self.recvbuf) < 4 + 12 + 4 + msglen:
                    return
                msg = self.recvbuf[4 + 12 + 4 : 4 + 12 + 4 + msglen]
                self.recvbuf = self.recvbuf[4 + 12 + 4 + msglen :]
            else:
                if len(self.recvbuf) < 4 + 12 + 4 + 4:
                    return
                command = self.recvbuf[4 : 4 + 12].split(b"\x00", 1)[0]
                msglen = struct.unpack("<i", self.recvbuf[4 + 12 : 4 + 12 + 4])[0]
                checksum = self.recvbuf[4 + 12 + 4 : 4 + 12 + 4 + 4]
                if len(self.recvbuf) < 4 + 12 + 4 + 4 + msglen:
                    return
                msg = self.recvbuf[4 + 12 + 4 + 4 : 4 + 12 + 4 + 4 + msglen]
                if self.allow0_checksum and checksum == b"\x00" * 4:
                    # Peer is permitted to skip checksums; just count them.
                    self.num0_checksums += 1
                elif checksum != sha256(sha256(msg))[:4]:
                    raise ValueError(f"got bad checksum {repr(self.recvbuf)}")
                self.recvbuf = self.recvbuf[4 + 12 + 4 + 4 + msglen :]
            if command in self.messagemap:
                decoded = self.messagemap[command]()
                decoded.deserialize(BytesIO(msg))
                await self.got_message(decoded)
            else:
                self.show_debug_msg(f"Unknown command: '{command}' {repr(msg)}")

    async def send_message(self, message, pushbuf=False):
        """Frame *message* and write it to the peer.

        When the connection is not in the "connected" state the frame is only
        buffered, which is allowed only with ``pushbuf=True``.
        """
        # NOTE(review): reconstructed up to the checksum branch; see got_data.
        if self.state != "connected" and not pushbuf:
            raise IOError("Not connected, no pushbuf")
        self.show_debug_msg(f"Send {repr(message)}")
        command = message.command
        data = message.serialize()
        tmsg = self.magic_bytes
        tmsg += command
        tmsg += b"\x00" * (12 - len(command))
        tmsg += struct.pack("<I", len(data))
        if self.ver_send >= 209:
            if self.produce0_checksum:
                tmsg += b"\x00" * 4
            else:
                th = sha256(data)
                h = sha256(th)
                tmsg += h[:4]
        tmsg += data
        self.sendbuf += tmsg
        self.last_sent = time.time()
        print("Sending", command, "state", self.state)
        if self.state == "connected":
            self.writer.write(self.sendbuf)
            await self.writer.drain()
            self.sendbuf = b""

    async def got_message(self, message):
        """Keep the link alive if idle for 30 minutes, then deliver *message*."""
        if self.last_sent + 30 * 60 < time.time():
            await self.send_message(self.messagemap[b"ping"]())
        print(f"Recv {repr(message)}")
        self.show_debug_msg(f"Recv {repr(message)}")
        await self.cb.deliver(self, message)

    def disconnect_node(self):
        """Ask the read loop to stop after its current read completes."""
        self.disconnect = True


# An exception we can raise if we detect a potential disconnect
# (p2p or rpc) before the test is complete
class EarlyDisconnectError(Exception):
    def __init__(self, value):
        self.value = value

    def __str__(self):
        return repr(self.value)


async def wait_for_block_in_chain_tips(node, blockhash, timeout=30):
    """Wait for *blockhash* to appear in node.getchaintips().

    :returns: that block's chaintip entry.
    :raises AssertionError: if it does not appear within *timeout* seconds.
    """
    start = time.time()
    while time.time() < start + timeout:
        for tip in node.getchaintips():
            if tip["hash"] == blockhash:
                return tip
        await asyncio.sleep(1)
    raise AssertionError(f"Block {str(blockhash)} never appeared in chain tips")


def _debug_log_expectations(reject_reason, expect_ban):
    """Build the (expected, unexpected) debug-log message lists used by the
    P2PDataStore send_*_and_test helpers."""
    ban_msg = "BAN THRESHOLD EXCEEDED"
    expected_msgs = []
    unexpected_msgs = []
    if reject_reason:
        expected_msgs.append(reject_reason)
    if expect_ban:
        expected_msgs.append(ban_msg)
    else:
        unexpected_msgs.append(ban_msg)
    return expected_msgs, unexpected_msgs


class P2PDataStore(SingleNodeConnCB):
    """A P2P data store class.

    Keeps a block and transaction store and responds correctly to getdata and getheaders requests.
    """

    def __init__(self):
        super().__init__()
        # store of blocks. key is block hash, value is a CBlock object
        self.block_store = {}
        self.last_block_hash = ""
        # store of txs. key is txid, value is a CTransaction object
        self.tx_store = {}
        # hashes of every getdata entry received, in arrival order
        self.getdata_requests = []

    def on_getdata(self, conn, message):
        """Reply with the stored tx/block for each requested inv entry."""

        async def request():
            for inv in message.inv:
                self.getdata_requests.append(inv.hash)
                if inv.type == CInv.MSG_TX and inv.hash in self.tx_store:
                    await self.send_message(MsgTx(self.tx_store[inv.hash]))
                elif inv.type == CInv.MSG_BLOCK and inv.hash in self.block_store:
                    await self.send_message(MsgBlock(self.block_store[inv.hash]))
                else:
                    logging.debug("getdata message type %s received.", hex(inv.type))

        _spawn(request())

    def on_getheaders(self, conn, message):
        """Search back through our block store for the locator, and reply with a headers message if found."""

        locator, hash_stop = message.locator, message.hashstop

        # Assume that the most recent block added is the tip
        if not self.block_store:
            return

        headers_list = [self.block_store[self.last_block_hash]]
        maxheaders = 2000
        while headers_list[-1].gethash() not in locator.vHave:
            # Walk back through the block store, adding headers to
            # headers_list as we go.
            prev_block_hash = headers_list[-1].hashPrevBlock
            if prev_block_hash not in self.block_store:
                logging.debug(
                    "block hash %s not found in block store", hex(prev_block_hash)
                )
                break
            prev_block_header = CBlockHeader(self.block_store[prev_block_hash])
            headers_list.append(prev_block_header)
            if prev_block_header.gethash() == hash_stop:
                # if this is the hashstop header, stop here
                break

        # Reverse into oldest-first order, keeping at most maxheaders entries
        # (the ones nearest the locator when the walk was longer).
        headers_list = headers_list[: -maxheaders - 1 : -1]
        _spawn(self.send_message(MsgHeaders(headers_list)))

    # pylint: disable=too-many-locals, too-many-arguments
    async def send_blocks_and_test(
        self,
        blocks,
        node,
        *,
        success=True,
        request_block=True,
        reject_reason=None,
        expect_ban=False,
        expect_disconnect=False,
        timeout=60,
    ):
        """Send blocks to test node and test whether the tip advances.

        - add all blocks to our block_store
        - send all headers
        - the on_getheaders handler will ensure that any getheaders are responded to
        - if request_block is True: wait for getdata for each of the blocks. The on_getdata handler will
          ensure that any getdata messages are responded to
        - if success is True: assert that the node's tip advances to the most recent block
        - if success is False: assert that the node's tip doesn't advance
        - if reject_reason is set: assert that the correct reject message is logged
        - expect_disconnect=True is not yet implemented and raises"""

        with mininode_lock:
            for block in blocks:
                self.block_store[block.gethash()] = block
                self.last_block_hash = block.gethash()

        expected_msgs, unexpected_msgs = _debug_log_expectations(
            reject_reason, expect_ban
        )

        with node.assert_debug_log(
            expected_msgs=expected_msgs, unexpected_msgs=unexpected_msgs
        ):
            await self.send_message(MsgHeaders([CBlockHeader(b) for b in blocks]))

            if request_block:
                ok = await wait_until(
                    lambda: blocks[-1].gethash() in self.getdata_requests,
                    timeout=timeout,
                )
                assert ok, f"did not receive getdata for {blocks[-1].gethash()}"

            if expect_disconnect:
                # self.wait_for_disconnect()
                raise Exception("NYI")
            await self.sync_with_ping()

            if success:
                ok = await wait_until(
                    lambda: node.getbestblockhash() == blocks[-1].hash, timeout=timeout
                )
                assert ok, f"node failed to sync to block {blocks[-1].gethash('hex')}"
            else:
                ct = await wait_for_block_in_chain_tips(node, blocks[-1].hash, timeout)
                # Was expecting failure but block is not invalid
                assert ct["status"] == "invalid"
                gbbh = node.getbestblockhash()
                assert gbbh != blocks[-1].hash

    # pylint: disable=too-many-arguments
    async def send_txs_and_test(
        self,
        txs,
        node,
        *,
        success=True,
        expect_ban=False,
        reject_reason=None,
        timeout=60,
    ):
        """Send txs to test node and test whether they're accepted to the mempool.

        - add all txs to our tx_store
        - send tx messages for all txs
        - if success is True/False: assert that the txs are/are not accepted to the mempool
        - if expect_ban is True: assert that the ban message is logged, otherwise that it is not
        - if reject_reason is set: assert that the correct reject message is logged."""

        assert txs, "send_txs_and_test needs at least one transaction"

        with mininode_lock:
            for tx in txs:
                self.tx_store[tx.get_id()] = tx

        expected_msgs, unexpected_msgs = _debug_log_expectations(
            reject_reason, expect_ban
        )

        with node.assert_debug_log(
            expected_msgs=expected_msgs, unexpected_msgs=unexpected_msgs
        ):
            for tx in txs:
                await self.send_message(MsgTx(tx))

            await self.sync_with_ping()

            async def wait_for_tx_not_in_mempool(tx: CTransaction):
                await wait_for(
                    timeout,
                    lambda: tx.get_rpc_hex_id() not in node.getrawtxpool(False, "id"),
                    on_error=f"{tx.get_rpc_hex_id()} tx still in mempool",
                )

            async def wait_for_tx_in_mempool(tx: CTransaction):
                await wait_for(
                    timeout,
                    lambda: tx.get_rpc_hex_id() in node.getrawtxpool(False, "id"),
                    on_error=f"{tx.get_rpc_hex_id()} tx not found in mempool",
                )

            if success:
                # Check that all txs are now in the mempool
                for tx in txs:
                    await wait_for_tx_in_mempool(tx)
            else:
                # Check that none of the txs are now in the mempool
                for tx in txs:
                    await wait_for_tx_not_in_mempool(tx)