You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
437 lines
11 KiB
437 lines
11 KiB
package cc |
|
|
|
import ( |
|
"fmt" |
|
"reflect" |
|
"strconv" |
|
"strings" |
|
"text/template" |
|
|
|
"github.com/envoyproxy/protoc-gen-validate/templates/shared" |
|
"github.com/iancoleman/strcase" |
|
pgs "github.com/lyft/protoc-gen-star" |
|
pgsgo "github.com/lyft/protoc-gen-star/lang/go" |
|
"google.golang.org/protobuf/types/known/durationpb" |
|
"google.golang.org/protobuf/types/known/timestamppb" |
|
) |
|
|
|
func RegisterModule(tpl *template.Template, params pgs.Parameters) { |
|
fns := CCFuncs{pgsgo.InitContext(params)} |
|
|
|
tpl.Funcs(map[string]interface{}{ |
|
"accessor": fns.accessor, |
|
"byteStr": fns.byteStr, |
|
"class": fns.className, |
|
"cmt": pgs.C80, |
|
"ctype": fns.cType, |
|
"durGt": fns.durGt, |
|
"durLit": fns.durLit, |
|
"durStr": fns.durStr, |
|
"err": fns.err, |
|
"errCause": fns.errCause, |
|
"errIdx": fns.errIdx, |
|
"errIdxCause": fns.errIdxCause, |
|
"hasAccessor": fns.hasAccessor, |
|
"inKey": fns.inKey, |
|
"inType": fns.inType, |
|
"isBytes": fns.isBytes, |
|
"lit": fns.lit, |
|
"lookup": fns.lookup, |
|
"oneof": fns.oneofTypeName, |
|
"output": fns.output, |
|
"package": fns.packageName, |
|
"quote": fns.quote, |
|
"staticVarName": fns.staticVarName, |
|
"tsGt": fns.tsGt, |
|
"tsLit": fns.tsLit, |
|
"tsStr": fns.tsStr, |
|
"typ": fns.Type, |
|
"unimplemented": fns.failUnimplemented, |
|
"unwrap": fns.unwrap, |
|
}) |
|
template.Must(tpl.Parse(moduleFileTpl)) |
|
template.Must(tpl.New("msg").Parse(msgTpl)) |
|
template.Must(tpl.New("const").Parse(constTpl)) |
|
template.Must(tpl.New("ltgt").Parse(ltgtTpl)) |
|
template.Must(tpl.New("in").Parse(inTpl)) |
|
template.Must(tpl.New("required").Parse(requiredTpl)) |
|
|
|
template.Must(tpl.New("none").Parse(noneTpl)) |
|
template.Must(tpl.New("float").Parse(numTpl)) |
|
template.Must(tpl.New("double").Parse(numTpl)) |
|
template.Must(tpl.New("int32").Parse(numTpl)) |
|
template.Must(tpl.New("int64").Parse(numTpl)) |
|
template.Must(tpl.New("uint32").Parse(numTpl)) |
|
template.Must(tpl.New("uint64").Parse(numTpl)) |
|
template.Must(tpl.New("sint32").Parse(numTpl)) |
|
template.Must(tpl.New("sint64").Parse(numTpl)) |
|
template.Must(tpl.New("fixed32").Parse(numTpl)) |
|
template.Must(tpl.New("fixed64").Parse(numTpl)) |
|
template.Must(tpl.New("sfixed32").Parse(numTpl)) |
|
template.Must(tpl.New("sfixed64").Parse(numTpl)) |
|
|
|
template.Must(tpl.New("bool").Parse(constTpl)) |
|
template.Must(tpl.New("string").Parse(strTpl)) |
|
template.Must(tpl.New("bytes").Parse(bytesTpl)) |
|
|
|
template.Must(tpl.New("email").Parse(emailTpl)) |
|
template.Must(tpl.New("hostname").Parse(hostTpl)) |
|
template.Must(tpl.New("address").Parse(hostTpl)) |
|
template.Must(tpl.New("uuid").Parse(uuidTpl)) |
|
|
|
template.Must(tpl.New("enum").Parse(enumTpl)) |
|
template.Must(tpl.New("message").Parse(messageTpl)) |
|
template.Must(tpl.New("repeated").Parse(repTpl)) |
|
template.Must(tpl.New("map").Parse(mapTpl)) |
|
|
|
template.Must(tpl.New("any").Parse(anyTpl)) |
|
template.Must(tpl.New("duration").Parse(durationTpl)) |
|
template.Must(tpl.New("timestamp").Parse(timestampTpl)) |
|
|
|
template.Must(tpl.New("wrapper").Parse(wrapperTpl)) |
|
} |
|
|
|
func RegisterHeader(tpl *template.Template, params pgs.Parameters) { |
|
fns := CCFuncs{pgsgo.InitContext(params)} |
|
|
|
tpl.Funcs(map[string]interface{}{ |
|
"class": fns.className, |
|
"output": fns.output, |
|
"screaming_snake_case": strcase.ToScreamingSnake, |
|
}) |
|
|
|
template.Must(tpl.Parse(headerFileTpl)) |
|
template.Must(tpl.New("decl").Parse(declTpl)) |
|
} |
|
|
|
// TODO(rodaine): break pgsgo dependency here (with equivalent pgscc subpackage) |
|
type CCFuncs struct{ pgsgo.Context } |
|
|
|
func CcFilePath(f pgs.File, ctx pgsgo.Context, tpl *template.Template) *pgs.FilePath { |
|
out := pgs.FilePath(f.Name().String()) |
|
out = out.SetExt(".pb.validate." + tpl.Name()) |
|
return &out |
|
} |
|
|
|
func (fns CCFuncs) methodName(name interface{}) string { |
|
nameStr := fmt.Sprintf("%s", name) |
|
switch nameStr { |
|
case "const": |
|
return "const_" |
|
case "inline": |
|
return "inline_" |
|
default: |
|
return nameStr |
|
} |
|
} |
|
|
|
func (fns CCFuncs) accessor(ctx shared.RuleContext) string { |
|
if ctx.AccessorOverride != "" { |
|
return ctx.AccessorOverride |
|
} |
|
|
|
return fmt.Sprintf( |
|
"m.%s()", |
|
fns.methodName(ctx.Field.Name())) |
|
} |
|
|
|
func (fns CCFuncs) hasAccessor(ctx shared.RuleContext) string { |
|
if ctx.AccessorOverride != "" { |
|
return "true" |
|
} |
|
|
|
return fmt.Sprintf( |
|
"m.has_%s()", |
|
fns.methodName(ctx.Field.Name())) |
|
} |
|
|
|
type childEntity interface { |
|
pgs.Entity |
|
Parent() pgs.ParentEntity |
|
} |
|
|
|
func (fns CCFuncs) classBaseName(ent childEntity) string { |
|
if m, ok := ent.Parent().(pgs.Message); ok { |
|
return fmt.Sprintf("%s_%s", fns.classBaseName(m), ent.Name().String()) |
|
} |
|
return ent.Name().String() |
|
} |
|
|
|
func (fns CCFuncs) className(ent childEntity) string { |
|
return fns.packageName(ent) + "::" + fns.classBaseName(ent) |
|
} |
|
|
|
func (fns CCFuncs) packageName(msg pgs.Entity) string { |
|
return "::" + strings.Join(msg.Package().ProtoName().Split(), "::") |
|
} |
|
|
|
func (fns CCFuncs) quote(s interface { |
|
String() string |
|
}) string { |
|
return strconv.Quote(s.String()) |
|
} |
|
|
|
func (fns CCFuncs) err(ctx shared.RuleContext, reason ...interface{}) string { |
|
return fns.errIdxCause(ctx, "", "nil", reason...) |
|
} |
|
|
|
func (fns CCFuncs) errCause(ctx shared.RuleContext, cause string, reason ...interface{}) string { |
|
return fns.errIdxCause(ctx, "", cause, reason...) |
|
} |
|
|
|
func (fns CCFuncs) errIdx(ctx shared.RuleContext, idx string, reason ...interface{}) string { |
|
return fns.errIdxCause(ctx, idx, "nil", reason...) |
|
} |
|
|
|
func (fns CCFuncs) errIdxCause(ctx shared.RuleContext, idx, cause string, reason ...interface{}) string { |
|
f := ctx.Field |
|
errName := fmt.Sprintf("%sValidationError", f.Message().Name()) |
|
|
|
output := []string{ |
|
"{", |
|
`std::ostringstream msg("invalid ");`, |
|
} |
|
|
|
if ctx.OnKey { |
|
output = append(output, `msg << "key for ";`) |
|
} |
|
output = append(output, |
|
fmt.Sprintf(`msg << %q << "." << %s;`, |
|
errName, fns.lit(pgsgo.PGGUpperCamelCase(f.Name())))) |
|
|
|
if idx != "" { |
|
output = append(output, fmt.Sprintf(`msg << "[" << %s << "]";`, idx)) |
|
} else if ctx.Index != "" { |
|
output = append(output, fmt.Sprintf(`msg << "[" << %s << "]";`, ctx.Index)) |
|
} |
|
|
|
output = append(output, fmt.Sprintf(`msg << ": " << %q;`, fmt.Sprint(reason...))) |
|
|
|
if cause != "nil" && cause != "" { |
|
output = append(output, fmt.Sprintf(`msg << " | caused by " << %s;`, cause)) |
|
} |
|
|
|
output = append(output, "*err = msg.str();", |
|
"return false;", |
|
"}") |
|
return strings.Join(output, "\n") |
|
} |
|
|
|
func (fns CCFuncs) lookup(f pgs.Field, name string) string { |
|
return fmt.Sprintf( |
|
"_%s_%s_%s", |
|
pgsgo.PGGUpperCamelCase(f.Message().Name()), |
|
pgsgo.PGGUpperCamelCase(f.Name()), |
|
name, |
|
) |
|
} |
|
|
|
func (fns CCFuncs) lit(x interface{}) string { |
|
val := reflect.ValueOf(x) |
|
|
|
if val.Kind() == reflect.Interface { |
|
val = val.Elem() |
|
} |
|
|
|
if val.Kind() == reflect.Ptr { |
|
val = val.Elem() |
|
} |
|
|
|
switch val.Kind() { |
|
case reflect.String: |
|
return fmt.Sprintf("%q", x) |
|
case reflect.Uint8: |
|
return fmt.Sprintf("%d", x) |
|
case reflect.Slice: |
|
els := make([]string, val.Len()) |
|
switch reflect.TypeOf(x).Elem().Kind() { |
|
case reflect.Uint8: |
|
for i, l := 0, val.Len(); i < l; i++ { |
|
els[i] = fmt.Sprintf("\\x%x", val.Index(i).Interface()) |
|
} |
|
return fmt.Sprintf("\"%s\"", strings.Join(els, "")) |
|
default: |
|
panic(fmt.Sprintf("don't know how to format literals of type %v", val.Kind())) |
|
} |
|
case reflect.Float32: |
|
return fmt.Sprintf("%fF", x) |
|
default: |
|
return fmt.Sprint(x) |
|
} |
|
} |
|
|
|
func (fns CCFuncs) isBytes(f interface { |
|
ProtoType() pgs.ProtoType |
|
}) bool { |
|
return f.ProtoType() == pgs.BytesT |
|
} |
|
|
|
func (fns CCFuncs) byteStr(x []byte) string { |
|
elms := make([]string, len(x)) |
|
for i, b := range x { |
|
elms[i] = fmt.Sprintf(`\x%X`, b) |
|
} |
|
|
|
return fmt.Sprintf(`"%s"`, strings.Join(elms, "")) |
|
} |
|
|
|
func (fns CCFuncs) oneofTypeName(f pgs.Field) pgsgo.TypeName { |
|
return pgsgo.TypeName(fmt.Sprintf("%s::%sCase::k%s", |
|
fns.className(f.Message()), |
|
pgsgo.PGGUpperCamelCase(f.OneOf().Name()), |
|
strings.ReplaceAll(pgsgo.PGGUpperCamelCase(f.Name()).String(), "_", ""))) |
|
} |
|
|
|
func (fns CCFuncs) inType(f pgs.Field, x interface{}) string { |
|
switch f.Type().ProtoType() { |
|
case pgs.BytesT: |
|
return "string" |
|
case pgs.MessageT: |
|
switch x.(type) { |
|
case []string: |
|
return "string" |
|
case []*durationpb.Duration: |
|
return "pgv::protobuf_wkt::Duration" |
|
default: |
|
return fns.className(f.Type().Element().Embed()) |
|
} |
|
default: |
|
return fns.cType(f.Type()) |
|
} |
|
} |
|
|
|
func (fns CCFuncs) cType(t pgs.FieldType) string { |
|
if t.IsEmbed() { |
|
return fns.className(t.Embed()) |
|
} |
|
if t.IsRepeated() { |
|
if t.ProtoType() == pgs.MessageT { |
|
return fns.className(t.Element().Embed()) |
|
} |
|
// Strip the leading [] |
|
return fns.cTypeOfString(fns.Type(t.Field()).String()[2:]) |
|
} else if t.IsMap() { |
|
if t.Element().IsEmbed() { |
|
return fns.className(t.Element().Embed()) |
|
} |
|
|
|
return fns.cTypeOfString(fns.Type(t.Field()).Element().String()) |
|
} |
|
|
|
return fns.cTypeOfString(fns.Type(t.Field()).String()) |
|
} |
|
|
|
func (fns CCFuncs) cTypeOfString(s string) string { |
|
switch s { |
|
case "float32": |
|
return "float" |
|
case "float64": |
|
return "double" |
|
case "int32": |
|
return "int32_t" |
|
case "int64": |
|
return "int64_t" |
|
case "uint32": |
|
return "uint32_t" |
|
case "uint64": |
|
return "uint64_t" |
|
case "[]byte": |
|
return "string" |
|
default: |
|
return s |
|
} |
|
} |
|
|
|
func (fns CCFuncs) inKey(f pgs.Field, x interface{}) string { |
|
switch f.Type().ProtoType() { |
|
case pgs.BytesT: |
|
return fns.byteStr(x.([]byte)) |
|
case pgs.MessageT: |
|
switch x := x.(type) { |
|
case *durationpb.Duration: |
|
return fns.durLit(x) |
|
default: |
|
return fns.lit(x) |
|
} |
|
case pgs.EnumT: |
|
return fmt.Sprintf("%s(%d)", fns.cType(f.Type()), x.(int32)) |
|
default: |
|
return fns.lit(x) |
|
} |
|
} |
|
|
|
func (fns CCFuncs) durLit(dur *durationpb.Duration) string { |
|
return fmt.Sprintf( |
|
"pgv::protobuf::util::TimeUtil::SecondsToDuration(%d) + pgv::protobuf::util::TimeUtil::NanosecondsToDuration(%d)", |
|
dur.GetSeconds(), dur.GetNanos()) |
|
} |
|
|
|
func (fns CCFuncs) durStr(dur *durationpb.Duration) string { |
|
d := dur.AsDuration() |
|
return d.String() |
|
} |
|
|
|
func (fns CCFuncs) durGt(a, b *durationpb.Duration) bool { |
|
ad := a.AsDuration() |
|
bd := b.AsDuration() |
|
|
|
return ad > bd |
|
} |
|
|
|
func (fns CCFuncs) tsLit(ts *timestamppb.Timestamp) string { |
|
return fmt.Sprintf( |
|
"time.Unix(%d, %d)", |
|
ts.GetSeconds(), ts.GetNanos(), |
|
) |
|
} |
|
|
|
func (fns CCFuncs) tsGt(a, b *timestamppb.Timestamp) bool { |
|
at := a.AsTime() |
|
bt := b.AsTime() |
|
|
|
return !bt.Before(at) |
|
} |
|
|
|
func (fns CCFuncs) tsStr(ts *timestamppb.Timestamp) string { |
|
t := ts.AsTime() |
|
return t.String() |
|
} |
|
|
|
func (fns CCFuncs) unwrap(ctx shared.RuleContext, name string) (shared.RuleContext, error) { |
|
ctx, err := ctx.Unwrap("wrapper") |
|
if err != nil { |
|
return ctx, err |
|
} |
|
|
|
ctx.AccessorOverride = fmt.Sprintf("%s.%s()", name, |
|
ctx.Field.Type().Embed().Fields()[0].Name()) |
|
|
|
return ctx, nil |
|
} |
|
|
|
func (fns CCFuncs) failUnimplemented(message string) string { |
|
if len(message) == 0 { |
|
return "throw pgv::UnimplementedException();" |
|
} |
|
|
|
return fmt.Sprintf(`throw pgv::UnimplementedException(%q);`, message) |
|
} |
|
|
|
func (fns CCFuncs) staticVarName(msg pgs.Message) string { |
|
return "validator_" + strings.Replace(fns.className(msg), ":", "_", -1) |
|
} |
|
|
|
func (fns CCFuncs) output(file pgs.File, ext string) string { |
|
return pgs.FilePath(file.Name().String()).SetExt(".pb" + ext).String() |
|
} |
|
|
|
func (fns CCFuncs) Type(f pgs.Field) pgsgo.TypeName { |
|
typ := fns.Context.Type(f) |
|
|
|
if f.Type().IsEnum() { |
|
parts := strings.Split(typ.String(), ".") |
|
typ = pgsgo.TypeName(parts[len(parts)-1]) |
|
} |
|
|
|
return typ |
|
}
|
|
|