package shared

import (
	"bytes"
	"fmt"
	"text/template"

	pgs "github.com/lyft/protoc-gen-star/v2"
	"google.golang.org/protobuf/proto"

	"github.com/envoyproxy/protoc-gen-validate/validate"
)

// RuleContext is the data handed to a validation template by Render. It pairs
// a field with the rules that apply to it and the name (Typ) of the template
// that renders those rules.
type RuleContext struct {
	// Field is the protobuf field the rules belong to.
	Field pgs.Field
	// Rules is the rule message selected by the FieldRules type oneof (see
	// resolveRules); it is nil when Typ is "none" or "error".
	Rules proto.Message
	// MessageRules holds the FieldRules.message settings, if any.
	MessageRules *validate.MessageRules
	// Typ is the rule type and doubles as the template name used by Render.
	// It is "wrapper" when scalar rules are applied to a message-typed field,
	// "none" when no rules apply, and "error" for an unrecognized rule type.
	Typ string
	// WrapperTyp is the scalar rule type when Typ is "wrapper".
	WrapperTyp string
	// OnKey is not set anywhere in this file.
	// NOTE(review): presumably marks map-key contexts; confirm with callers.
	OnKey bool
	// Index is the index/key expression supplied to Key or Elem.
	Index string
	// AccessorOverride is the accessor name supplied to Key, Elem or Unwrap.
	AccessorOverride string
}

// rulesContext builds the root RuleContext for field f from its
// (validate.rules) extension. Scalar rules on a message-typed field produce a
// "wrapper" context whose WrapperTyp holds the scalar rule type. An error is
// returned if the extension cannot be read or selects a rule type that
// resolveRules does not recognize.
func rulesContext(f pgs.Field) (out RuleContext, err error) {
	out.Field = f

	var rules validate.FieldRules
	if _, err = f.Extension(validate.E_Rules, &rules); err != nil {
		return out, err
	}

	var wrapped bool
	out.Typ, out.Rules, out.MessageRules, wrapped = resolveRules(f.Type(), &rules)
	if wrapped {
		out.WrapperTyp = out.Typ
		out.Typ = "wrapper"
	}

	if out.Typ == "error" {
		err = fmt.Errorf("unknown rule type (%T)", rules.Type)
	}
	return out, err
}

// Key returns the RuleContext for the keys of a map field. ctx.Rules must be a
// *validate.MapRules. The returned context uses name as its AccessorOverride
// and idx as its Index.
func (ctx RuleContext) Key(name, idx string) (out RuleContext, err error) {
	rules, ok := ctx.Rules.(*validate.MapRules)
	if !ok {
		// Report the rules value that failed the assertion, not the field.
		return out, fmt.Errorf("cannot get Key RuleContext from %T", ctx.Rules)
	}

	out.Field = ctx.Field
	out.AccessorOverride = name
	out.Index = idx

	keyRules := rules.GetKeys()
	// The wrapped result is discarded, so Key never yields a "wrapper" context.
	out.Typ, out.Rules, out.MessageRules, _ = resolveRules(ctx.Field.Type().Key(), keyRules)
	if out.Typ == "error" {
		// Name the unrecognized oneof member, as rulesContext does.
		err = fmt.Errorf("unknown rule type (%T)", keyRules.GetType())
	}
	return out, err
}

// Elem returns the RuleContext for the values of a map field (ctx.Rules is a
// *validate.MapRules) or the items of a repeated field (ctx.Rules is a
// *validate.RepeatedRules). Scalar rules on a message-typed element produce a
// "wrapper" context, as in rulesContext. The returned context uses name as
// its AccessorOverride and idx as its Index; those fields, plus Field, are
// populated even when an error is returned.
func (ctx RuleContext) Elem(name, idx string) (out RuleContext, err error) {
	out.Field = ctx.Field
	out.AccessorOverride = name
	out.Index = idx

	var rules *validate.FieldRules
	switch r := ctx.Rules.(type) {
	case *validate.MapRules:
		rules = r.GetValues()
	case *validate.RepeatedRules:
		rules = r.GetItems()
	default:
		// Report the rules value that failed the type switch, not the field.
		return out, fmt.Errorf("cannot get Elem RuleContext from %T", ctx.Rules)
	}

	var wrapped bool
	out.Typ, out.Rules, out.MessageRules, wrapped = resolveRules(ctx.Field.Type().Element(), rules)
	if wrapped {
		out.WrapperTyp = out.Typ
		out.Typ = "wrapper"
	}

	if out.Typ == "error" {
		// Name the unrecognized oneof member, as rulesContext does.
		err = fmt.Errorf("unknown rule type (%T)", rules.GetType())
	}
	return out, err
}

// Unwrap converts a "wrapper" context into one for the wrapped scalar type
// (ctx.WrapperTyp). It keeps Field, Rules and MessageRules and uses name as
// the AccessorOverride. It fails if ctx.Typ is not "wrapper".
//
// NOTE(review): Index and OnKey are not carried over, so an unwrapped
// map/repeated element loses its Index; confirm templates do not need it.
func (ctx RuleContext) Unwrap(name string) (out RuleContext, err error) {
	if ctx.Typ != "wrapper" {
		return out, fmt.Errorf("cannot unwrap non-wrapper type %q", ctx.Typ)
	}

	return RuleContext{
		Field:            ctx.Field,
		Rules:            ctx.Rules,
		MessageRules:     ctx.MessageRules,
		Typ:              ctx.WrapperTyp,
		AccessorOverride: name,
	}, nil
}

// Render returns a function that executes the template in tpl named by
// ctx.Typ with ctx as its data. Any output written before a failure is
// returned alongside the error.
func Render(tpl *template.Template) func(ctx RuleContext) (string, error) {
	return func(ctx RuleContext) (string, error) {
		var b bytes.Buffer
		err := tpl.ExecuteTemplate(&b, ctx.Typ, ctx)
		return b.String(), err
	}
}

// resolveRules maps the member selected in the FieldRules type oneof to a rule
// type name and the matching rule message. It returns:
//   - ruleType: the rule type ("float", "string", "map", ...), "repeated",
//     "map" or "message" for unruled repeated/map/message fields, "none" when
//     no rules apply, or "error" for an unrecognized oneof member;
//   - rule: the type-specific rule message (nil for "none" and "error");
//   - messageRule: the FieldRules.message settings, if any;
//   - wrapped: for float, double, (u)int32/64, bool, string and bytes rules,
//     whether typ is an embedded message type; false otherwise.
//
// rules may be nil (for example unset map key/value or repeated item rules);
// only nil-safe getters are used on it.
func resolveRules(typ interface{ IsEmbed() bool }, rules *validate.FieldRules) (ruleType string, rule proto.Message, messageRule *validate.MessageRules, wrapped bool) {
	switch r := rules.GetType().(type) {
	case *validate.FieldRules_Float:
		ruleType, rule, wrapped = "float", r.Float, typ.IsEmbed()
	case *validate.FieldRules_Double:
		ruleType, rule, wrapped = "double", r.Double, typ.IsEmbed()
	case *validate.FieldRules_Int32:
		ruleType, rule, wrapped = "int32", r.Int32, typ.IsEmbed()
	case *validate.FieldRules_Int64:
		ruleType, rule, wrapped = "int64", r.Int64, typ.IsEmbed()
	case *validate.FieldRules_Uint32:
		ruleType, rule, wrapped = "uint32", r.Uint32, typ.IsEmbed()
	case *validate.FieldRules_Uint64:
		ruleType, rule, wrapped = "uint64", r.Uint64, typ.IsEmbed()
	case *validate.FieldRules_Sint32:
		ruleType, rule, wrapped = "sint32", r.Sint32, false
	case *validate.FieldRules_Sint64:
		ruleType, rule, wrapped = "sint64", r.Sint64, false
	case *validate.FieldRules_Fixed32:
		ruleType, rule, wrapped = "fixed32", r.Fixed32, false
	case *validate.FieldRules_Fixed64:
		ruleType, rule, wrapped = "fixed64", r.Fixed64, false
	case *validate.FieldRules_Sfixed32:
		ruleType, rule, wrapped = "sfixed32", r.Sfixed32, false
	case *validate.FieldRules_Sfixed64:
		ruleType, rule, wrapped = "sfixed64", r.Sfixed64, false
	case *validate.FieldRules_Bool:
		ruleType, rule, wrapped = "bool", r.Bool, typ.IsEmbed()
	case *validate.FieldRules_String_:
		ruleType, rule, wrapped = "string", r.String_, typ.IsEmbed()
	case *validate.FieldRules_Bytes:
		ruleType, rule, wrapped = "bytes", r.Bytes, typ.IsEmbed()
	case *validate.FieldRules_Enum:
		ruleType, rule, wrapped = "enum", r.Enum, false
	case *validate.FieldRules_Repeated:
		ruleType, rule, wrapped = "repeated", r.Repeated, false
	case *validate.FieldRules_Map:
		ruleType, rule, wrapped = "map", r.Map, false
	case *validate.FieldRules_Any:
		ruleType, rule, wrapped = "any", r.Any, false
	case *validate.FieldRules_Duration:
		ruleType, rule, wrapped = "duration", r.Duration, false
	case *validate.FieldRules_Timestamp:
		ruleType, rule, wrapped = "timestamp", r.Timestamp, false
	case nil:
		// No type-specific rules. Only a typ that is a pgs.FieldType can be
		// classified as repeated or map here; everything else is checked for
		// being an embedded message.
		ft, isField := typ.(pgs.FieldType)
		switch {
		case isField && ft.IsRepeated():
			return "repeated", &validate.RepeatedRules{}, rules.GetMessage(), false
		case isField && ft.IsMap() && ft.Element().IsEmbed():
			return "map", &validate.MapRules{}, rules.GetMessage(), false
		case typ.IsEmbed():
			return "message", rules.GetMessage(), rules.GetMessage(), false
		}
		return "none", nil, nil, false
	default:
		ruleType, rule, wrapped = "error", nil, false
	}
	return ruleType, rule, rules.GetMessage(), wrapped
}