package runtime import ( "context" "errors" "fmt" "io" "net/http" "net/textproto" "github.com/golang/protobuf/proto" "github.com/grpc-ecosystem/grpc-gateway/internal" "google.golang.org/grpc/grpclog" ) var errEmptyResponse = errors.New("empty response") // ForwardResponseStream forwards the stream from gRPC server to REST client. func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) { f, ok := w.(http.Flusher) if !ok { grpclog.Infof("Flush not supported in %T", w) http.Error(w, "unexpected type of web server", http.StatusInternalServerError) return } md, ok := ServerMetadataFromContext(ctx) if !ok { grpclog.Infof("Failed to extract ServerMetadata from context") http.Error(w, "unexpected error", http.StatusInternalServerError) return } handleForwardResponseServerMetadata(w, mux, md) w.Header().Set("Transfer-Encoding", "chunked") w.Header().Set("Content-Type", marshaler.ContentType()) if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil { HTTPError(ctx, mux, marshaler, w, req, err) return } var delimiter []byte if d, ok := marshaler.(Delimited); ok { delimiter = d.Delimiter() } else { delimiter = []byte("\n") } var wroteHeader bool for { resp, err := recv() if err == io.EOF { return } if err != nil { handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err) return } if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil { handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err) return } var buf []byte switch { case resp == nil: buf, err = marshaler.Marshal(errorChunk(streamError(ctx, mux.streamErrorHandler, errEmptyResponse))) default: result := map[string]interface{}{"result": resp} if rb, ok := resp.(responseBody); ok { result["result"] = rb.XXX_ResponseBody() } buf, err = marshaler.Marshal(result) } if err != nil { grpclog.Infof("Failed to marshal response chunk: %v", err) handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err) return } if _, err = w.Write(buf); err != nil { grpclog.Infof("Failed to send response chunk: %v", err) return } wroteHeader = true if _, err = w.Write(delimiter); err != nil { grpclog.Infof("Failed to send delimiter chunk: %v", err) return } f.Flush() } } func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) { for k, vs := range md.HeaderMD { if h, ok := mux.outgoingHeaderMatcher(k); ok { for _, v := range vs { w.Header().Add(h, v) } } } } func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) { for k := range md.TrailerMD { tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)) w.Header().Add("Trailer", tKey) } } func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) { for k, vs := range md.TrailerMD { tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k) for _, v := range vs { w.Header().Add(tKey, v) } } } // responseBody interface contains method for getting field for marshaling to the response body // this method is generated for response struct from the value of `response_body` in the `google.api.HttpRule` type responseBody interface { XXX_ResponseBody() interface{} } // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client. func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) { md, ok := ServerMetadataFromContext(ctx) if !ok { grpclog.Infof("Failed to extract ServerMetadata from context") } handleForwardResponseServerMetadata(w, mux, md) handleForwardResponseTrailerHeader(w, md) contentType := marshaler.ContentType() // Check marshaler on run time in order to keep backwards compatability // An interface param needs to be added to the ContentType() function on // the Marshal interface to be able to remove this check if typeMarshaler, ok := marshaler.(contentTypeMarshaler); ok { contentType = typeMarshaler.ContentTypeFromMessage(resp) } w.Header().Set("Content-Type", contentType) if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil { HTTPError(ctx, mux, marshaler, w, req, err) return } var buf []byte var err error if rb, ok := resp.(responseBody); ok { buf, err = marshaler.Marshal(rb.XXX_ResponseBody()) } else { buf, err = marshaler.Marshal(resp) } if err != nil { grpclog.Infof("Marshal error: %v", err) HTTPError(ctx, mux, marshaler, w, req, err) return } if _, err = w.Write(buf); err != nil { grpclog.Infof("Failed to write response: %v", err) } handleForwardResponseTrailer(w, md) } func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error { if len(opts) == 0 { return nil } for _, opt := range opts { if err := opt(ctx, w, resp); err != nil { grpclog.Infof("Error handling ForwardResponseOptions: %v", err) return err } } return nil } func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) { serr := streamError(ctx, mux.streamErrorHandler, err) if !wroteHeader { w.WriteHeader(int(serr.HttpCode)) } buf, merr := marshaler.Marshal(errorChunk(serr)) if merr != nil { grpclog.Infof("Failed to marshal an error: %v", merr) return } if _, werr := w.Write(buf); werr != nil { grpclog.Infof("Failed to notify error to client: %v", werr) return } } // streamError returns the payload for the final message in a response stream // that represents the given err. func streamError(ctx context.Context, errHandler StreamErrorHandlerFunc, err error) *StreamError { serr := errHandler(ctx, err) if serr != nil { return serr } // TODO: log about misbehaving stream error handler? return DefaultHTTPStreamErrorHandler(ctx, err) } func errorChunk(err *StreamError) map[string]proto.Message { return map[string]proto.Message{"error": (*internal.StreamError)(err)} }