...

Source file src/github.com/twmb/franz-go/generate/gen.go

Documentation: github.com/twmb/franz-go/generate

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"strings"
     7  )
     8  
     9  func (Bool) TypeName() string                  { return "bool" }
    10  func (Int8) TypeName() string                  { return "int8" }
    11  func (Int16) TypeName() string                 { return "int16" }
    12  func (Uint16) TypeName() string                { return "uint16" }
    13  func (Int32) TypeName() string                 { return "int32" }
    14  func (Int64) TypeName() string                 { return "int64" }
    15  func (Float64) TypeName() string               { return "float64" }
    16  func (Uint32) TypeName() string                { return "uint32" }
    17  func (Varint) TypeName() string                { return "int32" }
    18  func (Varlong) TypeName() string               { return "int64" }
    19  func (Uuid) TypeName() string                  { return "[16]byte" }
    20  func (String) TypeName() string                { return "string" }
    21  func (NullableString) TypeName() string        { return "*string" }
    22  func (Bytes) TypeName() string                 { return "[]byte" }
    23  func (NullableBytes) TypeName() string         { return "[]byte" }
    24  func (VarintString) TypeName() string          { return "string" }
    25  func (VarintBytes) TypeName() string           { return "[]byte" }
    26  func (a Array) TypeName() string               { return "[]" + a.Inner.TypeName() }
    27  func (Throttle) TypeName() string              { return "int32" }
    28  func (FieldLengthMinusBytes) TypeName() string { return "[]byte" }
    29  
    30  func (s Struct) TypeName() string {
    31  	if s.Nullable {
    32  		return "*" + s.Name
    33  	}
    34  	return s.Name
    35  }
    36  
    37  func (e Enum) TypeName() string { return e.Name }
    38  func (e Enum) WriteAppend(l *LineWriter) {
    39  	l.Write("{")
    40  	l.Write("v := %s(v)", e.Type.TypeName())
    41  	e.Type.WriteAppend(l)
    42  	l.Write("}")
    43  }
    44  
    45  func (e Enum) WriteDecode(l *LineWriter) {
    46  	l.Write("var t %s", e.Name)
    47  	l.Write("{")
    48  	e.Type.WriteDecode(l)
    49  	l.Write("t = %s(v)", e.Name)
    50  	l.Write("}")
    51  	l.Write("v := t")
    52  }
    53  
    54  // primAppend corresponds to the primitive append functions in
    55  // kmsg/primitives.go.
    56  func primAppend(name string, l *LineWriter) {
    57  	l.Write("dst = kbin.Append%s(dst, v)", name)
    58  }
    59  
    60  func compactAppend(fromFlexible bool, name string, l *LineWriter) {
    61  	if fromFlexible {
    62  		l.Write("if isFlexible {")
    63  		primAppend("Compact"+name, l)
    64  		l.Write("} else {")
    65  		defer l.Write("}")
    66  	}
    67  	primAppend(name, l)
    68  }
    69  
    70  func (Bool) WriteAppend(l *LineWriter)         { primAppend("Bool", l) }
    71  func (Int8) WriteAppend(l *LineWriter)         { primAppend("Int8", l) }
    72  func (Int16) WriteAppend(l *LineWriter)        { primAppend("Int16", l) }
    73  func (Uint16) WriteAppend(l *LineWriter)       { primAppend("Uint16", l) }
    74  func (Int32) WriteAppend(l *LineWriter)        { primAppend("Int32", l) }
    75  func (Int64) WriteAppend(l *LineWriter)        { primAppend("Int64", l) }
    76  func (Float64) WriteAppend(l *LineWriter)      { primAppend("Float64", l) }
    77  func (Uint32) WriteAppend(l *LineWriter)       { primAppend("Uint32", l) }
    78  func (Varint) WriteAppend(l *LineWriter)       { primAppend("Varint", l) }
    79  func (Varlong) WriteAppend(l *LineWriter)      { primAppend("Varlong", l) }
    80  func (Uuid) WriteAppend(l *LineWriter)         { primAppend("Uuid", l) }
    81  func (VarintString) WriteAppend(l *LineWriter) { primAppend("VarintString", l) }
    82  func (VarintBytes) WriteAppend(l *LineWriter)  { primAppend("VarintBytes", l) }
    83  func (Throttle) WriteAppend(l *LineWriter)     { primAppend("Int32", l) }
    84  
    85  func (v String) WriteAppend(l *LineWriter) { compactAppend(v.FromFlexible, "String", l) }
    86  func (v NullableString) WriteAppend(l *LineWriter) {
    87  	if v.NullableVersion > 0 {
    88  		l.Write("if version < %d {", v.NullableVersion)
    89  		l.Write("var vv string")
    90  		l.Write("if v != nil {")
    91  		l.Write("vv = *v")
    92  		l.Write("}")
    93  		l.Write("{")
    94  		l.Write("v := vv")
    95  		compactAppend(v.FromFlexible, "String", l)
    96  		l.Write("}")
    97  		l.Write("} else {")
    98  		defer l.Write("}")
    99  	}
   100  	compactAppend(v.FromFlexible, "NullableString", l)
   101  }
   102  func (v Bytes) WriteAppend(l *LineWriter)         { compactAppend(v.FromFlexible, "Bytes", l) }
   103  func (v NullableBytes) WriteAppend(l *LineWriter) { compactAppend(v.FromFlexible, "NullableBytes", l) }
   104  
   105  func (FieldLengthMinusBytes) WriteAppend(l *LineWriter) {
   106  	l.Write("dst = append(dst, v...)")
   107  }
   108  
   109  func (a Array) WriteAppend(l *LineWriter) {
   110  	writeNullable := func() {
   111  		if a.FromFlexible {
   112  			l.Write("if isFlexible {")
   113  			l.Write("dst = kbin.AppendCompactNullableArrayLen(dst, len(v), v == nil)")
   114  			l.Write("} else {")
   115  			l.Write("dst = kbin.AppendNullableArrayLen(dst, len(v), v == nil)")
   116  			l.Write("}")
   117  		} else {
   118  			l.Write("dst = kbin.AppendNullableArrayLen(dst, len(v), v == nil)")
   119  		}
   120  	}
   121  
   122  	writeNormal := func() {
   123  		if a.FromFlexible {
   124  			l.Write("if isFlexible {")
   125  			l.Write("dst = kbin.AppendCompactArrayLen(dst, len(v))")
   126  			l.Write("} else {")
   127  			l.Write("dst = kbin.AppendArrayLen(dst, len(v))")
   128  			l.Write("}")
   129  		} else {
   130  			l.Write("dst = kbin.AppendArrayLen(dst, len(v))")
   131  		}
   132  	}
   133  
   134  	switch {
   135  	case a.IsVarintArray:
   136  		l.Write("dst = kbin.AppendVarint(dst, int32(len(v)))")
   137  	case a.IsNullableArray:
   138  		if a.NullableVersion > 0 {
   139  			l.Write("if version >= %d {", a.NullableVersion)
   140  			writeNullable()
   141  			l.Write("} else {")
   142  			writeNormal()
   143  			l.Write("}")
   144  		} else {
   145  			writeNullable()
   146  		}
   147  	default:
   148  		writeNormal()
   149  	}
   150  	l.Write("for i := range v {")
   151  	if s, isStruct := a.Inner.(Struct); isStruct && !s.Nullable {
   152  		// If the array elements are structs, we avoid copying the
   153  		// struct out and instead grab a pointer to the element.
   154  		l.Write("v := &v[i]")
   155  	} else {
   156  		l.Write("v := v[i]")
   157  	}
   158  	a.Inner.WriteAppend(l)
   159  	l.Write("}")
   160  }
   161  
   162  func (s Struct) WriteAppend(l *LineWriter) {
   163  	tags := make(map[int]StructField)
   164  	if s.Nullable {
   165  		l.Write("if v == nil {")
   166  		l.Write("dst = append(dst, 255)") // -1
   167  		l.Write("} else {")
   168  		l.Write("dst = append(dst, 1)")
   169  		defer l.Write("}")
   170  	}
   171  	for _, f := range s.Fields {
   172  		if onlyTag := f.writeBeginAndTag(l, tags); onlyTag {
   173  			continue
   174  		}
   175  		// If the struct field is a struct itself, we avoid copying it
   176  		// and instead grab a pointer.
   177  		if s, isStruct := f.Type.(Struct); isStruct && !s.Nullable {
   178  			l.Write("v := &v.%s", f.FieldName)
   179  		} else {
   180  			l.Write("v := v.%s", f.FieldName)
   181  		}
   182  		f.Type.WriteAppend(l)
   183  		l.Write("}")
   184  	}
   185  
   186  	if !s.FromFlexible {
   187  		return
   188  	}
   189  
   190  	l.Write("if isFlexible {")
   191  	defer l.Write("}")
   192  
   193  	var tagsCanDefault bool
   194  	for i := 0; i < len(tags); i++ {
   195  		f, exists := tags[i]
   196  		if !exists {
   197  			die("saw %d tags, but did not see tag %d; expected monotonically increasing", len(tags), i)
   198  		}
   199  		if _, tagsCanDefault = f.Type.(Defaulter); tagsCanDefault {
   200  			break
   201  		}
   202  	}
   203  
   204  	defer l.Write("dst = v.UnknownTags.AppendEach(dst)")
   205  
   206  	if tagsCanDefault {
   207  		l.Write("var toEncode []uint32")
   208  		for i := 0; i < len(tags); i++ {
   209  			f := tags[i]
   210  			canDefault := false
   211  			if d, ok := f.Type.(Defaulter); ok {
   212  				canDefault = true
   213  				def, has := d.GetDefault()
   214  				if !has {
   215  					def = d.GetTypeDefault()
   216  				}
   217  				switch t := f.Type.(type) {
   218  				case Struct:
   219  					l.Write("if !reflect.DeepEqual(v.%s, %v) {", f.FieldName, def)
   220  
   221  				case Array:
   222  					if t.IsNullableArray {
   223  						l.Write("if version < %[1]d && len(v.%[2]s) > 0 || version >= %[1]d && v.%[2]s != nil {", t.NullableVersion, f.FieldName)
   224  					} else {
   225  						l.Write("if len(v.%s) > 0 {", f.FieldName)
   226  					}
   227  
   228  				default:
   229  					l.Write("if v.%s != %v {", f.FieldName, def)
   230  				}
   231  			}
   232  			l.Write("toEncode = append(toEncode, %d)", i)
   233  			if canDefault {
   234  				l.Write("}")
   235  			}
   236  		}
   237  
   238  		l.Write("dst = kbin.AppendUvarint(dst, uint32(len(toEncode) + v.UnknownTags.Len()))")
   239  		l.Write("for _, tag := range toEncode {")
   240  		l.Write("switch tag {")
   241  		defer l.Write("}")
   242  		defer l.Write("}")
   243  	} else {
   244  		l.Write("dst = kbin.AppendUvarint(dst, %d + uint32(v.UnknownTags.Len()))", len(tags))
   245  	}
   246  
   247  	for i := 0; i < len(tags); i++ {
   248  		if tagsCanDefault {
   249  			l.Write("case %d:", i)
   250  		}
   251  		f := tags[i]
   252  
   253  		l.Write("{")
   254  		l.Write("v := v.%s", f.FieldName)
   255  		l.Write("dst = kbin.AppendUvarint(dst, %d)", i) // tag num
   256  		switch f.Type.(type) {
   257  		case Bool, Int8:
   258  			l.Write("dst = kbin.AppendUvarint(dst, 1)") // size
   259  			f.Type.WriteAppend(l)
   260  		case Int16, Uint16:
   261  			l.Write("dst = kbin.AppendUvarint(dst, 2)")
   262  			f.Type.WriteAppend(l)
   263  		case Int32, Uint32:
   264  			l.Write("dst = kbin.AppendUvarint(dst, 4)")
   265  			f.Type.WriteAppend(l)
   266  		case Int64, Float64:
   267  			l.Write("dst = kbin.AppendUvarint(dst, 8)")
   268  			f.Type.WriteAppend(l)
   269  		case Varint:
   270  			l.Write("dst = kbin.AppendUvarint(dst, kbin.VarintLen(v))")
   271  			f.Type.WriteAppend(l)
   272  		case Varlong:
   273  			l.Write("dst = kbin.AppendUvarint(dst, kbin.VarlongLen(v))")
   274  			f.Type.WriteAppend(l)
   275  		case Uuid:
   276  			l.Write("dst = kbin.AppendUvarint(dst, 16)")
   277  			f.Type.WriteAppend(l)
   278  		case Array, Struct, String, NullableString, Bytes, NullableBytes:
   279  			l.Write("sized := false")
   280  			l.Write("lenAt := len(dst)")
   281  			l.Write("f%s:", f.FieldName)
   282  			f.Type.WriteAppend(l)
   283  			l.Write("if !sized {")
   284  			l.Write("dst = kbin.AppendUvarint(dst[:lenAt], uint32(len(dst[lenAt:])))")
   285  			l.Write("sized = true")
   286  			l.Write("goto f%s", f.FieldName)
   287  			l.Write("}")
   288  		default:
   289  			die("tag type %v unsupported in append! fix this!", f.Type.TypeName())
   290  		}
   291  		l.Write("}")
   292  	}
   293  }
   294  
   295  // writeBeginAndTag begins a struct field encode/decode and adds the field to
   296  // the tags map if necessary. If this field is only tagged, this returns true.
   297  func (f StructField) writeBeginAndTag(l *LineWriter, tags map[int]StructField) (onlyTag bool) {
   298  	if f.MinVersion == -1 && f.MaxVersion > 0 {
   299  		die("unexpected negative min version %d while max version %d on field %s", f.MinVersion, f.MaxVersion, f.FieldName)
   300  	}
   301  	if f.Tag >= 0 {
   302  		if _, exists := tags[f.Tag]; exists {
   303  			die("unexpected duplicate tag %d on field %s", f.Tag, f.FieldName)
   304  		}
   305  		tags[f.Tag] = f
   306  	}
   307  	switch {
   308  	case f.MaxVersion > -1:
   309  		l.Write("if version >= %d && version <= %d {", f.MinVersion, f.MaxVersion)
   310  	case f.MinVersion > 0:
   311  		l.Write("if version >= %d {", f.MinVersion)
   312  	case f.MinVersion == -1:
   313  		if f.Tag < 0 {
   314  			die("unexpected min version -1 with tag %d on field %s", f.Tag, f.FieldName)
   315  		}
   316  		return true
   317  	default:
   318  		l.Write("{")
   319  	}
   320  	return false
   321  }
   322  
   323  // primDecode corresponds to the binreader primitive decoding functions in
   324  // kmsg/primitives.go.
   325  func primDecode(name string, l *LineWriter) {
   326  	l.Write("v := b.%s()", name)
   327  }
   328  
   329  func unsafeDecode(l *LineWriter, fn func(string)) {
   330  	l.Write("if unsafe {")
   331  	fn("Unsafe")
   332  	l.Write("} else {")
   333  	fn("")
   334  	l.Write("}")
   335  }
   336  
   337  func flexDecode(supports bool, l *LineWriter, fn func(string)) {
   338  	if supports {
   339  		l.Write("if isFlexible {")
   340  		fn("Compact")
   341  		l.Write("} else {")
   342  		defer l.Write("}")
   343  	}
   344  	fn("")
   345  }
   346  
   347  func primUnsafeDecode(name string, l *LineWriter) {
   348  	l.Write("var v string")
   349  	unsafeDecode(l, func(u string) {
   350  		l.Write("v = b.%s%s()", u, name)
   351  	})
   352  }
   353  
   354  func compactDecode(fromFlexible, hasUnsafe bool, name, typ string, l *LineWriter) {
   355  	if fromFlexible {
   356  		l.Write("var v %s", typ)
   357  		fn := func(u string) {
   358  			l.Write("if isFlexible {")
   359  			l.Write("v = b.%sCompact%s()", u, name)
   360  			l.Write("} else {")
   361  			l.Write("v = b.%s%s()", u, name)
   362  			l.Write("}")
   363  		}
   364  		if hasUnsafe {
   365  			unsafeDecode(l, fn)
   366  		} else {
   367  			fn("")
   368  		}
   369  	} else {
   370  		if hasUnsafe {
   371  			primUnsafeDecode(name, l)
   372  		} else {
   373  			primDecode(name, l)
   374  		}
   375  	}
   376  }
   377  
   378  func (Bool) WriteDecode(l *LineWriter)         { primDecode("Bool", l) }
   379  func (Int8) WriteDecode(l *LineWriter)         { primDecode("Int8", l) }
   380  func (Int16) WriteDecode(l *LineWriter)        { primDecode("Int16", l) }
   381  func (Uint16) WriteDecode(l *LineWriter)       { primDecode("Uint16", l) }
   382  func (Int32) WriteDecode(l *LineWriter)        { primDecode("Int32", l) }
   383  func (Int64) WriteDecode(l *LineWriter)        { primDecode("Int64", l) }
   384  func (Float64) WriteDecode(l *LineWriter)      { primDecode("Float64", l) }
   385  func (Uint32) WriteDecode(l *LineWriter)       { primDecode("Uint32", l) }
   386  func (Varint) WriteDecode(l *LineWriter)       { primDecode("Varint", l) }
   387  func (Varlong) WriteDecode(l *LineWriter)      { primDecode("Varlong", l) }
   388  func (Uuid) WriteDecode(l *LineWriter)         { primDecode("Uuid", l) }
   389  func (VarintString) WriteDecode(l *LineWriter) { primUnsafeDecode("VarintString", l) }
   390  func (VarintBytes) WriteDecode(l *LineWriter)  { primDecode("VarintBytes", l) }
   391  func (Throttle) WriteDecode(l *LineWriter)     { primDecode("Int32", l) }
   392  
   393  func (v String) WriteDecode(l *LineWriter) {
   394  	compactDecode(v.FromFlexible, true, "String", "string", l)
   395  }
   396  
   397  func (v Bytes) WriteDecode(l *LineWriter) {
   398  	compactDecode(v.FromFlexible, false, "Bytes", "[]byte", l)
   399  }
   400  
   401  func (v NullableBytes) WriteDecode(l *LineWriter) {
   402  	compactDecode(v.FromFlexible, false, "NullableBytes", "[]byte", l)
   403  }
   404  
   405  func (v NullableString) WriteDecode(l *LineWriter) {
   406  	// If there is a nullable version, we write a "read string, then set
   407  	// pointer" block.
   408  	l.Write("var v *string")
   409  	if v.NullableVersion > 0 {
   410  		l.Write("if version < %d {", v.NullableVersion)
   411  		l.Write("var vv string")
   412  		flexDecode(v.FromFlexible, l, func(compact string) {
   413  			unsafeDecode(l, func(u string) {
   414  				l.Write("vv = b.%s%sString()", u, compact)
   415  			})
   416  		})
   417  		l.Write("v = &vv")
   418  		l.Write("} else {")
   419  		defer l.Write("}")
   420  	}
   421  
   422  	flexDecode(v.FromFlexible, l, func(compact string) {
   423  		unsafeDecode(l, func(u string) {
   424  			l.Write("v = b.%s%sNullableString()", u, compact)
   425  		})
   426  	})
   427  }
   428  
   429  func (f FieldLengthMinusBytes) WriteDecode(l *LineWriter) {
   430  	l.Write("v := b.Span(int(s.%s) - %d)", f.Field, f.LengthMinus)
   431  }
   432  
   433  func (a Array) WriteDecode(l *LineWriter) {
   434  	// For decoding arrays, we copy our "v" variable to our own "a"
   435  	// variable so that the scope opened just below can use its own
   436  	// v variable. At the end, we reset v with any updates to a.
   437  	l.Write("a := v")
   438  	l.Write("var l int32")
   439  
   440  	if a.IsVarintArray {
   441  		l.Write("l = b.VarintArrayLen()")
   442  	} else {
   443  		flexDecode(a.FromFlexible, l, func(compact string) {
   444  			l.Write("l = b.%sArrayLen()", compact)
   445  		})
   446  		if a.IsNullableArray {
   447  			l.Write("if version < %d || l == 0 {", a.NullableVersion)
   448  			l.Write("a = %s{}", a.TypeName())
   449  			l.Write("}")
   450  		}
   451  	}
   452  
   453  	l.Write("if !b.Ok() {")
   454  	l.Write("return b.Complete()")
   455  	l.Write("}")
   456  
   457  	l.Write("a = a[:0]")
   458  
   459  	l.Write("if l > 0 {")
   460  	l.Write("a = append(a, make(%s, l)...)", a.TypeName())
   461  	l.Write("}")
   462  
   463  	l.Write("for i := int32(0); i < l; i++ {")
   464  	switch t := a.Inner.(type) {
   465  	case Struct:
   466  		if t.Nullable {
   467  			l.Write("if present := b.Int8(); present != -1 && b.Ok() {")
   468  			defer l.Write("}")
   469  		}
   470  		l.Write("v := &a[i]")
   471  		l.Write("v.Default()") // set defaults first
   472  	case Array:
   473  		// With nested arrays, we declare a new v and introduce scope
   474  		// so that the next level will not collide with our current "a".
   475  		l.Write("v := a[i]")
   476  		l.Write("{")
   477  	}
   478  
   479  	a.Inner.WriteDecode(l)
   480  
   481  	if _, isArray := a.Inner.(Array); isArray {
   482  		// With nested arrays, now we release our scope.
   483  		l.Write("}")
   484  	}
   485  
   486  	if _, isStruct := a.Inner.(Struct); !isStruct {
   487  		l.Write("a[i] = v")
   488  	}
   489  
   490  	l.Write("}") // close the for loop
   491  
   492  	l.Write("v = a")
   493  }
   494  
   495  func (f StructField) WriteDecode(l *LineWriter) {
   496  	switch t := f.Type.(type) {
   497  	case Struct:
   498  		// For decoding a nested struct, we copy a pointer out.
   499  		// The nested version will then set the fields directly.
   500  		if t.Nullable {
   501  			l.Write("if present := b.Int8(); present != -1 && b.Ok() {")
   502  			l.Write("s.%s = new(%s)", f.FieldName, t.Name)
   503  			l.Write("v := s.%s", f.FieldName)
   504  			defer l.Write("}")
   505  		} else {
   506  			l.Write("v := &s.%s", f.FieldName)
   507  		}
   508  		l.Write("v.Default()")
   509  	case Array:
   510  		// For arrays, we need to copy the array into a v
   511  		// field so that the array function can use it.
   512  		l.Write("v := s.%s", f.FieldName)
   513  	default:
   514  		// All other types use primDecode, which does a `v :=`.
   515  	}
   516  	f.Type.WriteDecode(l)
   517  
   518  	_, isStruct := f.Type.(Struct)
   519  	if !isStruct {
   520  		// If the field was not a struct or it was a nullable struct,
   521  		// we need to copy the changes back.
   522  		l.Write("s.%s = v", f.FieldName)
   523  	}
   524  }
   525  
// WriteDecode emits code that decodes the fields of the struct v from the
// reader b.
//
// For "with version field" structs, the leading Version field is validated
// and skipped here. The generated readFrom (see WriteDecodeFunc) reads it
// with `v.Version = b.Int16()` before this decode runs. Regular fields are
// decoded inside the block opened by writeBeginAndTag. For structs that
// support flexible versions, the emitted code then reads the tagged-field
// section under isFlexible. Each known tag is decoded from its own sub-reader
// over the tag's sized payload, and any other tag is stored in s.UnknownTags.
func (s Struct) WriteDecode(l *LineWriter) {
	if len(s.Fields) == 0 {
		return
	}
	rangeFrom := s.Fields
	if s.WithVersionField {
		f := s.Fields[0]
		if f.FieldName != "Version" {
			die("expected first field in 'with version field' type to be version, is %s", f.FieldName)
		}
		if f.Type != (Int16{}) {
			die("expected field version type to be int16, was %v", f.Type)
		}
		rangeFrom = s.Fields[1:]
	}
	l.Write("s := v")

	tags := make(map[int]StructField)

	for _, f := range rangeFrom {
		if onlyTag := f.writeBeginAndTag(l, tags); onlyTag {
			continue
		}
		f.WriteDecode(l)
		l.Write("}") // closes the block opened by writeBeginAndTag
	}

	if !s.FromFlexible {
		return
	}

	l.Write("if isFlexible {")
	// With no known tags, every tag is unknown; read them all at once.
	if len(tags) == 0 {
		l.Write("s.UnknownTags = internalReadTags(&b)")
		l.Write("}")
		return
	}
	// The three deferred "}" lines run LIFO at return. They close the
	// switch, then the for loop, then the isFlexible block.
	defer l.Write("}")

	l.Write("for i := b.Uvarint(); i > 0; i-- {")
	defer l.Write("}")

	l.Write("switch key := b.Uvarint(); key {")
	defer l.Write("}")

	l.Write("default:")
	l.Write("s.UnknownTags.Set(key, b.Span(int(b.Uvarint())))")

	for i := 0; i < len(tags); i++ {
		f, exists := tags[i]
		if !exists {
			die("saw %d tags, but did not see tag %d; expected monotonically increasing", len(tags), i)
		}

		l.Write("case %d:", i)
		// Shadow b with a reader over just this tag's payload so the field
		// decode cannot read past it, then require the payload be consumed
		// without error.
		l.Write("b := kbin.Reader{Src: b.Span(int(b.Uvarint()))}")
		f.WriteDecode(l)
		l.Write("if err := b.Complete(); err != nil {")
		l.Write("return err")
		l.Write("}")
	}
}
   588  
   589  func (s Struct) WriteDefault(l *LineWriter) {
   590  	if len(s.Fields) == 0 || s.Nullable {
   591  		return
   592  	}
   593  
   594  	// Like decoding above, we skip the version field.
   595  	rangeFrom := s.Fields
   596  	if s.WithVersionField {
   597  		f := s.Fields[0]
   598  		if f.FieldName != "Version" {
   599  			die("expected first field in 'with version field' type to be version, is %s", f.FieldName)
   600  		}
   601  		if f.Type != (Int16{}) {
   602  			die("expected field version type to be int16, was %v", f.Type)
   603  		}
   604  		rangeFrom = s.Fields[1:]
   605  	}
   606  
   607  	for _, f := range rangeFrom {
   608  		switch inner := f.Type.(type) {
   609  		case Struct:
   610  			l.Write("{")
   611  			l.Write("v := &v.%s", f.FieldName)
   612  			l.Write("_ = v")
   613  			inner.WriteDefault(l)
   614  			l.Write("}")
   615  		default:
   616  			if d, ok := f.Type.(Defaulter); ok {
   617  				def, has := d.GetDefault()
   618  				if has {
   619  					l.Write("v.%s = %v", f.FieldName, def)
   620  				}
   621  			}
   622  		}
   623  	}
   624  }
   625  
   626  func (s Struct) WriteDefn(l *LineWriter) {
   627  	if s.Comment != "" {
   628  		l.Write(s.Comment)
   629  	}
   630  	l.Write("type %s struct {", s.Name)
   631  	if s.TopLevel {
   632  		// Top level messages always have a Version field.
   633  		l.Write("// Version is the version of this message used with a Kafka broker.")
   634  		l.Write("Version int16")
   635  		l.Write("")
   636  	}
   637  	for i, f := range s.Fields {
   638  		if f.Comment != "" {
   639  			l.Write("%s", f.Comment)
   640  		}
   641  		versionTag := ""
   642  		switch {
   643  		case f.MinVersion > 0 && f.MaxVersion > 0:
   644  			versionTag = fmt.Sprintf(" // v%d-v%d", f.MinVersion, f.MaxVersion)
   645  		case f.MinVersion > 0:
   646  			versionTag = fmt.Sprintf(" // v%d+", f.MinVersion)
   647  		case f.MaxVersion > 0:
   648  			versionTag = fmt.Sprintf(" // v0-v%d", f.MaxVersion)
   649  		}
   650  		if f.Tag >= 0 {
   651  			if versionTag == "" {
   652  				versionTag += " // tag "
   653  			} else {
   654  				versionTag += ", tag "
   655  			}
   656  			versionTag += strconv.Itoa(f.Tag)
   657  		}
   658  		l.Write("%s %s%s", f.FieldName, f.Type.TypeName(), versionTag)
   659  		if i < len(s.Fields)-1 {
   660  			l.Write("") // blank between fields
   661  		}
   662  	}
   663  	if s.FlexibleAt >= 0 {
   664  		l.Write("")
   665  		l.Write("// UnknownTags are tags Kafka sent that we do not know the purpose of.")
   666  		if s.FlexibleAt == 0 {
   667  			l.Write("UnknownTags Tags")
   668  		} else {
   669  			l.Write("UnknownTags Tags // v%d+", s.FlexibleAt)
   670  		}
   671  		l.Write("")
   672  	}
   673  	l.Write("}")
   674  }
   675  
   676  func (s Struct) WriteKeyFunc(l *LineWriter) {
   677  	l.Write("func (*%s) Key() int16 { return %d }", s.Name, s.Key)
   678  }
   679  
   680  func (s Struct) WriteMaxVersionFunc(l *LineWriter) {
   681  	l.Write("func (*%s) MaxVersion() int16 { return %d }", s.Name, s.MaxVersion)
   682  }
   683  
   684  func (s Struct) WriteGetVersionFunc(l *LineWriter) {
   685  	l.Write("func (v *%s) GetVersion() int16 { return v.Version }", s.Name)
   686  }
   687  
   688  func (s Struct) WriteSetVersionFunc(l *LineWriter) {
   689  	l.Write("func (v *%s) SetVersion(version int16) { v.Version = version }", s.Name)
   690  }
   691  
   692  func (s Struct) WriteAdminFunc(l *LineWriter) {
   693  	l.Write("func (v *%s) IsAdminRequest() {}", s.Name)
   694  }
   695  
   696  func (s Struct) WriteGroupCoordinatorFunc(l *LineWriter) {
   697  	l.Write("func (v *%s) IsGroupCoordinatorRequest() {}", s.Name)
   698  }
   699  
   700  func (s Struct) WriteTxnCoordinatorFunc(l *LineWriter) {
   701  	l.Write("func (v *%s) IsTxnCoordinatorRequest() {}", s.Name)
   702  }
   703  
   704  func (s Struct) WriteResponseKindFunc(l *LineWriter) {
   705  	l.Write("func (v *%s) ResponseKind() Response {", s.Name)
   706  	l.Write("r := &%s{Version: v.Version }", s.ResponseKind)
   707  	l.Write("r.Default()")
   708  	l.Write("return r")
   709  	l.Write("}")
   710  }
   711  
   712  func (s Struct) WriteRequestKindFunc(l *LineWriter) {
   713  	l.Write("func (v *%s) RequestKind() Request { return &%s{Version: v.Version }}", s.Name, s.RequestKind)
   714  }
   715  
   716  func (s Struct) WriteIsFlexibleFunc(l *LineWriter) {
   717  	if s.FlexibleAt >= 0 {
   718  		l.Write("func (v *%s) IsFlexible() bool { return v.Version >= %d }", s.Name, s.FlexibleAt)
   719  	} else {
   720  		l.Write("func (v *%s) IsFlexible() bool { return false }", s.Name)
   721  	}
   722  }
   723  
   724  func (s Struct) WriteThrottleMillisFunc(f StructField, l *LineWriter) {
   725  	t := f.Type.(Throttle)
   726  	l.Write("func (v *%s) Throttle() (int32, bool) { return v.ThrottleMillis, v.Version >= %d }", s.Name, t.Switchup)
   727  }
   728  
   729  func (s Struct) WriteSetThrottleMillisFunc(l *LineWriter) {
   730  	l.Write("func (v *%s) SetThrottle(throttleMillis int32) { v.ThrottleMillis = throttleMillis}", s.Name)
   731  }
   732  
   733  func (s Struct) WriteTimeoutMillisFunc(l *LineWriter) {
   734  	l.Write("func (v *%s) Timeout() int32 { return v.TimeoutMillis }", s.Name)
   735  }
   736  
   737  func (s Struct) WriteSetTimeoutMillisFunc(l *LineWriter) {
   738  	l.Write("func (v *%s) SetTimeout(timeoutMillis int32) { v.TimeoutMillis = timeoutMillis }", s.Name)
   739  }
   740  
   741  func (s Struct) WriteAppendFunc(l *LineWriter) {
   742  	l.Write("func (v *%s) AppendTo(dst []byte) []byte {", s.Name)
   743  	if s.TopLevel || s.WithVersionField {
   744  		l.Write("version := v.Version")
   745  		l.Write("_ = version")
   746  	}
   747  	if s.FlexibleAt >= 0 {
   748  		l.Write("isFlexible := version >= %d", s.FlexibleAt)
   749  		l.Write("_ = isFlexible")
   750  	}
   751  	s.WriteAppend(l)
   752  	l.Write("return dst")
   753  	l.Write("}")
   754  }
   755  
   756  func (s Struct) WriteDecodeFunc(l *LineWriter) {
   757  	l.Write("func (v *%s) ReadFrom(src []byte) error {", s.Name)
   758  	l.Write("return v.readFrom(src, false)")
   759  	l.Write("}")
   760  
   761  	l.Write("func (v *%s) UnsafeReadFrom(src []byte) error {", s.Name)
   762  	l.Write("return v.readFrom(src, true)")
   763  	l.Write("}")
   764  
   765  	l.Write("func (v *%s) readFrom(src []byte, unsafe bool) error {", s.Name)
   766  	l.Write("v.Default()")
   767  	l.Write("b := kbin.Reader{Src: src}")
   768  	if s.WithVersionField {
   769  		l.Write("v.Version = b.Int16()")
   770  	}
   771  	if s.TopLevel || s.WithVersionField {
   772  		l.Write("version := v.Version")
   773  		l.Write("_ = version")
   774  	}
   775  	if s.FlexibleAt >= 0 {
   776  		l.Write("isFlexible := version >= %d", s.FlexibleAt)
   777  		l.Write("_ = isFlexible")
   778  	}
   779  	s.WriteDecode(l)
   780  	l.Write("return b.Complete()")
   781  	l.Write("}")
   782  }
   783  
   784  func (s Struct) WriteRequestWithFunc(l *LineWriter) {
   785  	l.Write("// RequestWith is requests v on r and returns the response or an error.")
   786  	l.Write("// For sharded requests, the response may be merged and still return an error.")
   787  	l.Write("// It is better to rely on client.RequestSharded than to rely on proper merging behavior.")
   788  	l.Write("func (v *%s) RequestWith(ctx context.Context, r Requestor) (*%s, error) {", s.Name, s.ResponseKind)
   789  	l.Write("kresp, err := r.Request(ctx, v)")
   790  	l.Write("resp, _ := kresp.(*%s)", s.ResponseKind)
   791  	l.Write("return resp, err")
   792  	l.Write("}")
   793  }
   794  
   795  func (s Struct) WriteDefaultFunc(l *LineWriter) {
   796  	l.Write("// Default sets any default fields. Calling this allows for future compatibility")
   797  	l.Write("// if new fields are added to %s.", s.Name)
   798  	l.Write("func (v *%s) Default() {", s.Name)
   799  	s.WriteDefault(l)
   800  	l.Write("}")
   801  }
   802  
   803  func (s Struct) WriteNewFunc(l *LineWriter) {
   804  	l.Write("// New%[1]s returns a default %[1]s", s.Name)
   805  	l.Write("// This is a shortcut for creating a struct and calling Default yourself.")
   806  	l.Write("func New%[1]s() %[1]s {", s.Name)
   807  	l.Write("var v %s", s.Name)
   808  	l.Write("v.Default()")
   809  	l.Write("return v")
   810  	l.Write("}")
   811  }
   812  
   813  func (s Struct) WriteNewPtrFunc(l *LineWriter) {
   814  	l.Write("// NewPtr%[1]s returns a pointer to a default %[1]s", s.Name)
   815  	l.Write("// This is a shortcut for creating a new(struct) and calling Default yourself.")
   816  	l.Write("func NewPtr%[1]s() *%[1]s {", s.Name)
   817  	l.Write("var v %s", s.Name)
   818  	l.Write("v.Default()")
   819  	l.Write("return &v")
   820  	l.Write("}")
   821  }
   822  
   823  func (e Enum) WriteDefn(l *LineWriter) {
   824  	if e.Comment != "" {
   825  		l.Write(e.Comment)
   826  		l.Write("// ")
   827  	}
   828  	l.Write("// Possible values and their meanings:")
   829  	l.Write("// ")
   830  	for _, v := range e.Values {
   831  		l.Write("// * %d (%s)", v.Value, v.Word)
   832  		if len(v.Comment) > 0 {
   833  			l.Write(v.Comment)
   834  		}
   835  		l.Write("//")
   836  	}
   837  	l.Write("type %s %s", e.Name, e.Type.TypeName())
   838  }
   839  
   840  func (e Enum) WriteStringFunc(l *LineWriter) {
   841  	l.Write("func (v %s) String() string {", e.Name)
   842  	l.Write("switch v {")
   843  	l.Write("default:")
   844  	if e.CamelCase {
   845  		l.Write(`return "Unknown"`)
   846  	} else {
   847  		l.Write(`return "UNKNOWN"`)
   848  	}
   849  	for _, v := range e.Values {
   850  		l.Write("case %d:", v.Value)
   851  		l.Write(`return "%s"`, v.Word)
   852  	}
   853  	l.Write("}")
   854  	l.Write("}")
   855  }
   856  
   857  func (e Enum) WriteStringsFunc(l *LineWriter) {
   858  	l.Write("func %sStrings() []string {", e.Name)
   859  	l.Write("return []string{")
   860  	for _, v := range e.Values {
   861  		l.Write(`"%s",`, v.Word)
   862  	}
   863  	l.Write("}")
   864  	l.Write("}")
   865  }
   866  
   867  func (e Enum) WriteParseFunc(l *LineWriter) {
   868  	l.Write("// Parse%s normalizes the input s and returns", e.Name)
   869  	l.Write("// the value represented by the string.")
   870  	l.Write("//")
   871  	l.Write("// Normalizing works by stripping all dots, underscores, and dashes,")
   872  	l.Write("// trimming spaces, and lowercasing.")
   873  	l.Write("func Parse%[1]s(s string) (%[1]s, error) {", e.Name)
   874  	l.Write("switch strnorm(s) {")
   875  	for _, v := range e.Values {
   876  		l.Write(`case "%s":`, strnorm(v.Word))
   877  		l.Write("return %d, nil", v.Value)
   878  	}
   879  	l.Write("default:")
   880  	l.Write(`return 0, fmt.Errorf("%s: unable to parse %%q", s)`, e.Name)
   881  	l.Write("}")
   882  	l.Write("}")
   883  }
   884  
   885  func (e Enum) WriteUnmarshalTextFunc(l *LineWriter) {
   886  	l.Write("// UnmarshalText implements encoding.TextUnmarshaler.")
   887  	l.Write("func (e *%s) UnmarshalText(text []byte) error {", e.Name)
   888  	l.Write("v, err := Parse%s(string(text))", e.Name)
   889  	l.Write("*e = v")
   890  	l.Write("return err")
   891  	l.Write("}")
   892  }
   893  
   894  func (e Enum) WriteMarshalTextFunc(l *LineWriter) {
   895  	l.Write("// MarshalText implements encoding.TextMarshaler.")
   896  	l.Write("func (e %s) MarshalText() (text []byte, err error) {", e.Name)
   897  	l.Write("return []byte(e.String()), nil")
   898  	l.Write("}")
   899  }
   900  
   901  func strnorm(s string) string {
   902  	s = strings.ReplaceAll(s, ".", "")
   903  	s = strings.ReplaceAll(s, "_", "")
   904  	s = strings.ReplaceAll(s, "-", "")
   905  	s = strings.TrimSpace(s)
   906  	s = strings.ToLower(s)
   907  	return s
   908  }
   909  
   910  func writeStrnorm(l *LineWriter) {
   911  	l.Write(`func strnorm(s string) string {`)
   912  	l.Write(`s = strings.ReplaceAll(s, ".", "")`)
   913  	l.Write(`s = strings.ReplaceAll(s, "_", "")`)
   914  	l.Write(`s = strings.ReplaceAll(s, "-", "")`)
   915  	l.Write(`s = strings.TrimSpace(s)`)
   916  	l.Write(`s = strings.ToLower(s)`)
   917  	l.Write(`return s`)
   918  	l.Write(`}`)
   919  }
   920  
   921  func (e Enum) WriteConsts(l *LineWriter) {
   922  	l.Write("const (")
   923  	if !e.HasZero {
   924  		l.Write("%[1]sUnknown %[1]s = 0", e.Name)
   925  	}
   926  	defer l.Write(")")
   927  	for _, v := range e.Values {
   928  		var sb strings.Builder
   929  		if e.CamelCase {
   930  			sb.WriteString(v.Word)
   931  		} else {
   932  			upper := true
   933  			for _, c := range v.Word {
   934  				switch c {
   935  				case '_':
   936  					upper = true
   937  				default:
   938  					s := string([]rune{c})
   939  					if upper {
   940  						sb.WriteString(strings.ToUpper(s))
   941  					} else {
   942  						sb.WriteString(strings.ToLower(s))
   943  					}
   944  					upper = false
   945  				}
   946  			}
   947  		}
   948  
   949  		l.Write("%s%s %s = %d", e.Name, sb.String(), e.Name, v.Value)
   950  	}
   951  }
   952  

View as plain text