#!/usr/bin/env python3
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Posts benchmark results to gist and comments on pull requests.

Requires the environment variables:

- GITHUB_TOKEN: token from GitHub action that has write access on issues.
    See https://docs.github.com/en/actions/security-guides/automatic-token-authentication#permissions-for-the-github_token
- COMMENT_BOT_USER: user name that posts the comment. Note this can be
    different from the user that creates the gist.
- GIST_BOT_TOKEN: token that has write access to gist. Gist will be posted as
    the owner of the token.
    See https://docs.github.com/en/rest/overview/permissions-required-for-fine-grained-personal-access-tokens#gists
"""

import sys
import pathlib

# Add build_tools python dir to the search path.
sys.path.insert(0, str(pathlib.Path(__file__).parent.with_name("python")))

import argparse
import http.client
import json
import os
import requests
from typing import Any, Optional

from reporting import benchmark_comment

GITHUB_IREE_API_PREFIX = "https://api.github.com/repos/openxla/iree"
GITHUB_GIST_API = "https://api.github.com/gists"
GITHUB_API_VERSION = "2022-11-28"


class APIRequester(object):
    """REST API client that injects proper GitHub authentication headers."""

    def __init__(self, github_token: str):
        """Prepares the shared headers and HTTP session.

        Args:
          github_token: token sent in the `Authorization` header of every
            request made through this requester.
        """
        self._api_headers = {
            "Accept": "application/vnd.github+json",
            "Authorization": f"token {github_token}",
            "X-GitHub-Api-Version": GITHUB_API_VERSION,
        }
        self._session = requests.Session()

    @staticmethod
    def _json_body(payload: Optional[Any]) -> str:
        """Serializes a payload to JSON, treating None as an empty object."""
        return json.dumps({} if payload is None else payload)

    def get(self, endpoint: str, payload: Optional[Any] = None) -> requests.Response:
        """Sends a GET request.

        Args:
          endpoint: full URL to request.
          payload: optional mapping of query parameters. It is URL-encoded
            into the query string rather than sent as a request body, because
            GET parameters such as `per_page` and `page` are query parameters.

        Returns:
          The raw response; callers are responsible for checking the status.
        """
        return self._session.get(
            endpoint, params=payload or None, headers=self._api_headers
        )

    def post(self, endpoint: str, payload: Optional[Any] = None) -> requests.Response:
        """Sends a POST request with `payload` serialized as the JSON body.

        Args:
          endpoint: full URL to request.
          payload: JSON-serializable body; None is sent as an empty object.

        Returns:
          The raw response; callers are responsible for checking the status.
        """
        return self._session.post(
            endpoint, data=self._json_body(payload), headers=self._api_headers
        )

    def patch(self, endpoint: str, payload: Optional[Any] = None) -> requests.Response:
        """Sends a PATCH request with `payload` serialized as the JSON body.

        Args:
          endpoint: full URL to request.
          payload: JSON-serializable body; None is sent as an empty object.

        Returns:
          The raw response; callers are responsible for checking the status.
        """
        return self._session.patch(
            endpoint, data=self._json_body(payload), headers=self._api_headers
        )


class GithubClient(object):
    """Helper to call Github REST APIs."""

    def __init__(self, requester: APIRequester):
        """Stores the requester used for every API call of this client."""
        self._requester = requester

    def post_to_gist(self, filename: str, content: str, verbose: bool = False) -> str:
        """Posts the given content to a new GitHub Gist and returns the URL to it.

        Args:
          filename: name of the single file created in the gist.
          content: text content of that file.
          verbose: if True, prints the decoded API response.

        Returns:
          The `html_url` field of the created gist.

        Raises:
          RuntimeError: if the API does not answer 201 Created, or if the
            response reports the gist as truncated.
        """
        response = self._requester.post(
            endpoint=GITHUB_GIST_API,
            payload={"public": True, "files": {filename: {"content": content}}},
        )
        if response.status_code != http.client.CREATED:
            raise RuntimeError(
                f"Failed to create on gist; error code: {response.status_code} - {response.text}"
            )

        gist = response.json()
        if verbose:
            print(f"Gist posting response: {gist}")

        if gist["truncated"]:
            raise RuntimeError("Content is too large and was truncated")

        return gist["html_url"]

    def get_previous_comment_on_pr(
        self,
        pr_number: int,
        comment_bot_user: str,
        comment_type_id: str,
        query_comment_per_page: int = 100,
        max_pages_to_search: int = 10,
        verbose: bool = False,
    ) -> Optional[int]:
        """Gets the previous comment's id from GitHub.

        Pages through the comments of the pull request and returns the first
        one that was written by `comment_bot_user` and whose body contains
        `comment_type_id`.

        Args:
          pr_number: pull request (issue) number to search.
          comment_bot_user: login of the user whose comments are considered.
          comment_type_id: marker string that must appear in the comment body.
          query_comment_per_page: number of comments requested per page.
          max_pages_to_search: upper bound on the number of pages fetched.
          verbose: if True, prints the decoded response of every page.

        Returns:
          The id of the matching comment, or None if no comment matches
          within the searched pages.

        Raises:
          RuntimeError: if any page request does not answer 200 OK.
        """
        for page in range(1, max_pages_to_search + 1):
            # NOTE(review): confirm that this endpoint honors `sort` and
            # `direction`; if it does not, the first match is not necessarily
            # the most recently updated one.
            response = self._requester.get(
                endpoint=f"{GITHUB_IREE_API_PREFIX}/issues/{pr_number}/comments",
                payload={
                    "per_page": query_comment_per_page,
                    "page": page,
                    "sort": "updated",
                    "direction": "desc",
                },
            )
            if response.status_code != http.client.OK:
                raise RuntimeError(
                    f"Failed to get PR comments from GitHub; error code: {response.status_code} - {response.text}"
                )

            comments = response.json()
            if verbose:
                print(f"Previous comment query response on page {page}: {comments}")

            # Find the most recently updated comment that matches.
            for comment in comments:
                if (
                    comment["user"]["login"] == comment_bot_user
                    and comment_type_id in comment["body"]
                ):
                    return comment["id"]

            # A short page means there are no further pages to fetch.
            if len(comments) < query_comment_per_page:
                break

        return None

    def update_comment_on_pr(self, comment_id: int, content: str):
        """Updates the content of the given comment id.

        Raises:
          RuntimeError: if the API does not answer 200 OK.
        """
        response = self._requester.patch(
            endpoint=f"{GITHUB_IREE_API_PREFIX}/issues/comments/{comment_id}",
            payload={"body": content},
        )
        if response.status_code != http.client.OK:
            raise RuntimeError(
                f"Failed to comment on GitHub; error code: {response.status_code} - {response.text}"
            )

    def create_comment_on_pr(self, pr_number: int, content: str):
        """Posts the given content as comments to the current pull request.

        Raises:
          RuntimeError: if the API does not answer 201 Created.
        """
        response = self._requester.post(
            endpoint=f"{GITHUB_IREE_API_PREFIX}/issues/{pr_number}/comments",
            payload={"body": content},
        )
        if response.status_code != http.client.CREATED:
            raise RuntimeError(
                f"Failed to comment on GitHub; error code: {response.status_code} - {response.text}"
            )

    def get_pull_request_head_commit(self, pr_number: int) -> str:
        """Get pull request head commit SHA.

        Raises:
          RuntimeError: if the API does not answer 200 OK.
        """
        response = self._requester.get(
            endpoint=f"{GITHUB_IREE_API_PREFIX}/pulls/{pr_number}"
        )
        if response.status_code != http.client.OK:
            raise RuntimeError(
                f"Failed to fetch the pull request: {pr_number}; "
                f"error code: {response.status_code} - {response.text}"
            )

        return response.json()["head"]["sha"]


def _parse_arguments():
    """Parses the command line: the comment JSON path and verification flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument("comment_json", type=pathlib.Path)
    parser.add_argument("--verbose", action="store_true")
    verification_parser = parser.add_mutually_exclusive_group(required=True)
    verification_parser.add_argument("--github_event_json", type=pathlib.Path)
    return parser.parse_args()


def _require_env(name: str) -> str:
    """Returns the value of environment variable `name`, or raises ValueError."""
    value = os.environ.get(name)
    if value is None:
        raise ValueError(f"{name} must be set.")
    return value


def main(args: argparse.Namespace):
    """Posts the full results to a gist and creates or updates the PR comment.

    Args:
      args: parsed command line; reads `comment_json`, `github_event_json`
        and `verbose`.

    Raises:
      ValueError: if a required environment variable is unset, or if the
        workflow run SHA does not match the pull request head SHA.
      RuntimeError: propagated from the GitHub API helpers on failed requests.
    """
    github_token = _require_env("GITHUB_TOKEN")
    comment_bot_user = _require_env("COMMENT_BOT_USER")
    gist_bot_token = _require_env("GIST_BOT_TOKEN")

    comment_data = benchmark_comment.CommentData(
        **json.loads(args.comment_json.read_text())
    )
    # Sanitize the pr number to make sure it is an integer.
    pr_number = int(comment_data.unverified_pr_number)

    pr_client = GithubClient(requester=APIRequester(github_token=github_token))
    if args.github_event_json is None:
        github_event = None
    else:
        github_event = json.loads(args.github_event_json.read_text())
        workflow_run_sha = github_event["workflow_run"]["head_sha"]
        pr_head_sha = pr_client.get_pull_request_head_commit(pr_number=pr_number)
        # We can't get the trusted PR number of a workflow run from GitHub API. So we
        # take the untrusted PR number from presubmit workflow and verify if the PR's
        # current head SHA matches the commit SHA in the workflow run. It assumes
        # that to generate the malicious comment data, attacker must modify the code
        # and has a new commit SHA. So if the PR head commit matches the workflow
        # run with attacker's commit, either the PR is created by the attacker or
        # other's PR has the malicious commit. In both cases posting malicious
        # comment is acceptable.
        #
        # Note that the collision of a target SHA1 is possible but GitHub has some
        # protections (https://github.blog/2017-03-20-sha-1-collision-detection-on-github-com/).
        # The assumption also only holds if files in GCS can't be overwritten (so the
        # comment data can't be modified without changing the code).
        # The check will also fail if the PR author pushes the new commit after the
        # workflow is triggered. But pushing the new commit means to cancel the
        # current CI run including the benchmarking. So it will unlikely fail for
        # that reason.
        #
        # The comparison lives inside this branch because both SHAs are only
        # defined when an event JSON was supplied.
        if workflow_run_sha != pr_head_sha:
            raise ValueError(
                f"Workflow run SHA: {workflow_run_sha} does not match "
                f"the head SHA: {pr_head_sha} of the pull request: {pr_number}."
            )

    gist_client = GithubClient(requester=APIRequester(github_token=gist_bot_token))
    gist_url = gist_client.post_to_gist(
        filename=f"iree-full-benchmark-results-{pr_number}.md",
        content=comment_data.full_md,
        verbose=args.verbose,
    )

    previous_comment_id = pr_client.get_previous_comment_on_pr(
        pr_number=pr_number,
        comment_bot_user=comment_bot_user,
        comment_type_id=comment_data.type_id,
        verbose=args.verbose,
    )

    abbr_md = comment_data.abbr_md.replace(
        benchmark_comment.GIST_LINK_PLACEHORDER, gist_url
    )
    if github_event is not None:
        abbr_md += (
            f'\n\n[Source Workflow Run]({github_event["workflow_run"]["html_url"]})'
        )
    if previous_comment_id is not None:
        pr_client.update_comment_on_pr(comment_id=previous_comment_id, content=abbr_md)
    else:
        pr_client.create_comment_on_pr(pr_number=pr_number, content=abbr_md)


if __name__ == "__main__":
    main(_parse_arguments())