// Copyright 2022 Risc0, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "risc0/core/key.h" #include "risc0/zkp/core/fp.h" #include "risc0/zkvm/circuit/constants.h" #include "risc0/zkvm/platform/io.h" #include #include #include namespace risc0 { using BufferU8 = std::vector; using BufferU32 = std::vector; struct MemoryEvent { uint32_t addr; uint32_t cycle; bool isWrite; uint32_t data; bool operator<(const MemoryEvent& rhs) const { if (addr != rhs.addr) { return addr < rhs.addr; } return cycle < rhs.cycle; } }; struct MemoryState { std::map data; std::set history; void dump(size_t logLevel); uint8_t loadByte(uint32_t addr); uint32_t load(uint32_t addr); uint32_t loadBE(uint32_t addr); void loadRegion(uint32_t addr, void* ptr, uint32_t len); void storeByte(uint32_t addr, uint8_t byte); void store(uint32_t addr, uint32_t value); void store(uint32_t addr, const void* ptr, uint32_t len); size_t strlen(uint32_t addr); }; struct IoHandler { virtual void onInit(MemoryState& mem) {} virtual BufferU8 onSendRecv(uint32_t channelId, const BufferU8& data) { return BufferU8(); } virtual void onCommit(const BufferU8& data) {} virtual void onFault(const std::string& msg); virtual KeyStore& getKeyStore() = 0; }; class MemoryHandler { public: MemoryHandler(); MemoryHandler(IoHandler* io); // Called before the load of the ELF. Can write to memory, but any memory loaded from the // ELF will override. virtual void onInit(MemoryState& mem); // Called after loading an ELF, can write to any memory not already loaded. virtual void onLoaded(MemoryState& mem) {} // onRead is called when uninitalized memory is read, giving the host a chance to return with a // value. virtual uint32_t onRead(MemoryState& mem, uint32_t addr) { return 0; } // onWrite is called when a word is written to memory, giving the host a chance to be notified of // new data. virtual void onWrite(MemoryState& mem, uint32_t cycle, uint32_t addr, uint32_t value); // Called after the system is halted, gets final memory state & final output. virtual void onHalt(const MemoryState& mem, const std::array& output) {} private: IoHandler* io; // Memory address of current host->guest transmission. The host can only // write to each memory location once, so this advances after each write. uint32_t cur_host_to_guest_offset; }; struct StepContext { MemoryHandler* io; MemoryState mem; uint32_t curStep; uint32_t numSteps; Fp globals[kGlobalSize]; Fp get(const Fp* buf, size_t offset, size_t back); void set(Fp* buf, size_t offset, Fp val); Fp getDigits(const Fp* buf, size_t bits, size_t offset, size_t back, size_t size); Fp setDigits(Fp* buf, size_t bits, size_t offset, size_t size, Fp val); Fp getMux(const Fp* buf, size_t offset, size_t back, size_t size); void setMux(Fp* buf, size_t offset, size_t size, Fp val); void memWrite(Fp cycle, Fp addr, Fp low, Fp high); std::array memRead(Fp cycle, Fp addr); std::array memCheck(); // Cycle, Addr, IsWrite, Low, High std::array divide(Fp numerLow, Fp numerHigh, Fp denomLow, Fp denomHigh); void requireDigits(Fp* buf, size_t bits, size_t offset, size_t size); void requireMux(Fp* buf, size_t offset, size_t size, const char* msg); void requireZero(Fp val, const char* msg); }; void setupCode(Fp* code, size_t numSteps, uint32_t startAddr, const std::map& image); void dataStepExec(StepContext& ctx, const Fp* code, Fp* data); void dataStepCheck(StepContext& ctx, const Fp* code, Fp* data); void accumStep(StepContext& ctx, const Fp* code, const Fp* data, Fp* accum); } // namespace risc0