// Copyright 2022 Risc0, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "risc0/zkvm/sdk/cpp/host/receipt.h" #include "risc0/core/log.h" #include "risc0/zkp/core/sha256_cpu.h" #include "risc0/zkp/prove/prove.h" #include "risc0/zkp/verify/verify.h" #include "risc0/zkvm/prove/method_id.h" #include "risc0/zkvm/prove/riscv.h" #include "risc0/zkvm/verify/riscv.h" #include namespace risc0 { void Receipt::verify(const std::string& filename) const { LOG(1, "Reading code id from " << filename); MethodID code = readMethodID(filename); std::unique_ptr circuit = getRiscVVerifyCircuit(code); risc0::verify(*circuit, seal.data(), seal.size()); if (journal.size() != seal[8]) { std::stringstream ss; ss << "Receipt::verify> journal size (" << journal.size() << ") does not match receipt seal (" << seal[8] << ")"; throw std::runtime_error(ss.str()); } if (journal.size() > 32) { ShaDigest digest = shaHash(journal.data(), journal.size()); if (memcmp(&digest, seal.data(), sizeof(ShaDigest)) != 0) { throw std::runtime_error("Receipt journal/seal root mismatch"); } } else { if (memcmp(journal.data(), seal.data(), journal.size()) != 0) { throw std::runtime_error("Receipt journal/seal root mismatch"); } } } struct Prover::Impl : public IoHandler { Impl(const std::string& elfPath, const std::string& idPath) : elfPath(elfPath) , idPath(idPath) , outputStream(outputBuffer) , commitStream(commitBuffer) , inputWriter(inputStream) , outputReader(outputStream) , commitReader(commitStream) {} virtual ~Impl() {} void onInit(MemoryState& mem) override { LOG(1, "Prover::onInit>"); uint32_t addr = kMemInputStart; for (uint32_t word : inputStream.vec) { if (addr > kMemInputEnd) { throw std::runtime_error("Out of memory: inputs"); } LOG(1, " " << hex(addr) << ": " << hex(word)); mem.store(addr, word); addr += sizeof(uint32_t); } } void onWrite(const BufferU8& buf) override { LOG(1, "IoHandler::onWrite> " << buf.size()); outputBuffer.insert(outputBuffer.end(), buf.begin(), buf.end()); } void onCommit(const BufferU8& buf) override { LOG(1, "IoHandler::onCommit> " << buf.size()); commitBuffer.insert(commitBuffer.end(), buf.begin(), buf.end()); } KeyStore& getKeyStore() override { return keyStore; } std::string elfPath; std::string idPath; KeyStore keyStore; BufferU8 outputBuffer; BufferU8 commitBuffer; VectorStreamWriter inputStream; CheckedStreamReader outputStream; CheckedStreamReader commitStream; ArchiveWriter inputWriter; ArchiveReader outputReader; ArchiveReader commitReader; }; CheckedStreamReader::CheckedStreamReader(const BufferU8& buffer) : buffer(buffer), cursor(0) {} uint32_t CheckedStreamReader::read_word() { if (cursor + sizeof(uint32_t) > buffer.size()) { throw(std::out_of_range("Read out of bounds")); } uint32_t b1 = buffer[cursor++]; uint32_t b2 = buffer[cursor++]; uint32_t b3 = buffer[cursor++]; uint32_t b4 = buffer[cursor++]; return b1 | b2 << 8 | b3 << 16 | b4 << 24; } uint64_t CheckedStreamReader::read_dword() { uint64_t low = read_word(); uint64_t high = read_word(); return low | high << 32; } void CheckedStreamReader::read_buffer(void* buf, size_t len) { size_t end_cursor = align(cursor + len); if (end_cursor > buffer.size()) { throw(std::out_of_range("Read out of bounds")); } memcpy(buf, buffer.data(), len); cursor = end_cursor; } Prover::Prover(const std::string& elfPath, const std::string& idPath) : impl(new Impl(elfPath, idPath)) {} Prover::~Prover() = default; KeyStore& Prover::getKeyStore() { return impl->getKeyStore(); } void Prover::setKey(const std::string& name, const Key& key) { impl->getKeyStore()[name] = key; } const BufferU8& Prover::getOutput() { return impl->outputBuffer; } const BufferU8& Prover::getCommit() { return impl->commitBuffer; } ArchiveWriter& Prover::getInputWriter() { return impl->inputWriter; } ArchiveReader& Prover::getOutputReader() { return impl->outputReader; } ArchiveReader& Prover::getCommitReader() { return impl->commitReader; } void Prover::writeInput(const void* ptr, size_t size) { LOG(1, "Prover::writeInput> size: " << size); const uint8_t* ptr_u8 = static_cast(ptr); while (size >= sizeof(uint32_t)) { uint32_t word = 0; word |= *ptr_u8++; word |= *ptr_u8++ << 8; word |= *ptr_u8++ << 16; word |= *ptr_u8++ << 24; LOG(1, " write_word: " << hex(word)); impl->inputStream.write_word(word); size -= sizeof(uint32_t); } if (size) { LOG(1, " tail: " << size); uint32_t word = 0; for (size_t i = 0; i < size; i++) { word |= *ptr_u8++ << (8 * i); } LOG(1, " write_word: " << hex(word)); impl->inputStream.write_word(word); } } Receipt Prover::run() { // Set the memory handlers to call back to the impl MemoryHandler handler(impl.get()); // Make the circuit std::unique_ptr circuit = getRiscVProveCircuit(impl->elfPath.c_str(), handler); BufferU32 seal = prove(*circuit); // Attach the full version of the output journal + construct receipt object Receipt receipt{getCommit(), seal}; // Verify receipt to make sure it works receipt.verify(impl->idPath); return receipt; } } // namespace risc0