// Copyright 2016 The Cockroach Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License.

// This file generates batch_generated.go. It can be run via:
//    go run -tags gen-batch gen_batch.go

// +build gen-batch

package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
	"reflect"
	"strings"

	"github.com/cockroachdb/cockroach/pkg/roachpb"
)

// variantInfo describes one variant of a generated protobuf oneof: the
// wrapper type that implements the union interface (isRequestUnion_Value,
// isResponseUnion_Value, isErrorDetail_Value) plus the message it carries.
// All fields are strings, so values are comparable and usable as map keys.
type variantInfo struct {
	// variantType is the Go name of the wrapper type, i.e. the type that
	// is assigned to a union's Value field.
	variantType string
	// variantName is the distinguishing suffix of variantType; the wrapper
	// struct's only field carries this same name.
	variantName string
	// msgType is the Go name of the message wrapped by the variant (a
	// Request or Response type, or an error type for ErrorDetail).
	msgType string
}

var errVariants []variantInfo
var reqVariants []variantInfo
var resVariants []variantInfo
var reqResVariantMapping map[variantInfo]variantInfo

// initVariant builds a variantInfo from a pointer to a oneof wrapper struct.
// The wrapper's single field must be a pointer to the wrapped message type;
// its name and pointee type supply variantName and msgType.
func initVariant(varInstance interface{}) variantInfo {
	wrapper := reflect.TypeOf(varInstance).Elem()
	only := wrapper.Field(0) // variants always have 1 field

	var info variantInfo
	info.variantType = wrapper.Name()
	info.variantName = only.Name
	info.msgType = only.Type.Elem().Name()
	return info
}

// initVariants populates errVariants, resVariants and reqVariants from the
// oneof wrappers of roachpb.ErrorDetail, roachpb.ResponseUnion and
// roachpb.RequestUnion, and records in reqResVariantMapping the response
// variant paired with every request variant. It panics if a request variant
// has no matching response variant.
func initVariants() {
	_, _, _, errWrappers := (&roachpb.ErrorDetail{}).XXX_OneofFuncs()
	for _, wrapper := range errWrappers {
		errVariants = append(errVariants, initVariant(wrapper))
	}

	_, _, _, resWrappers := (&roachpb.ResponseUnion{}).XXX_OneofFuncs()
	resByName := make(map[string]variantInfo, len(resWrappers))
	for _, wrapper := range resWrappers {
		info := initVariant(wrapper)
		resVariants = append(resVariants, info)
		resByName[info.variantName] = info
	}

	// A request variant normally shares its name with its response variant;
	// the entries here are the exceptions (request name -> response name).
	renamed := map[string]string{
		"TransferLease": "RequestLease",
	}

	_, _, _, reqWrappers := (&roachpb.RequestUnion{}).XXX_OneofFuncs()
	reqResVariantMapping = make(map[variantInfo]variantInfo, len(reqWrappers))
	for _, wrapper := range reqWrappers {
		reqInfo := initVariant(wrapper)
		reqVariants = append(reqVariants, reqInfo)

		resName := reqInfo.variantName
		if alt, found := renamed[resName]; found {
			resName = alt
		}
		resInfo, found := resByName[resName]
		if !found {
			panic(fmt.Sprintf("unknown response variant %q", resName))
		}
		reqResVariantMapping[reqInfo] = resInfo
	}
}

// genGetInner writes a GetInner method for the union type unionName. The
// generated method type-switches on ru.GetValue() and, for each entry in
// variants, returns the field held by that wrapper as a variantName value.
// When nothing matches, it returns nil.
func genGetInner(w io.Writer, unionName, variantName string, variants []variantInfo) {
	const (
		head = `
// GetInner returns the %[2]s contained in the union.
func (ru %[1]s) GetInner() %[2]s {
	switch t := ru.GetValue().(type) {
`
		arm = `	case *%s:
		return t.%s
`
		tail = `	default:
		return nil
	}
}
`
	)

	fmt.Fprintf(w, head, unionName, variantName)
	for i := range variants {
		fmt.Fprintf(w, arm, variants[i].variantType, variants[i].variantName)
	}
	fmt.Fprint(w, tail)
}

// genSetInner writes a SetInner method for the union type unionName. The
// generated method type-switches on its variantName argument. For each entry
// in variants, it wraps a *msgType value in the matching wrapper, stores it in
// ru.Value and returns true. For any other type, it leaves ru untouched and
// returns false.
func genSetInner(w io.Writer, unionName, variantName string, variants []variantInfo) {
	const (
		head = `
// SetInner sets the %[2]s in the union.
func (ru *%[1]s) SetInner(r %[2]s) bool {
	var union is%[1]s_Value
	switch t := r.(type) {
`
		arm = `	case *%s:
		union = &%s{t}
`
		tail = `	default:
		return false
	}
	ru.Value = union
	return true
}
`
	)

	fmt.Fprintf(w, head, unionName, variantName)
	for i := range variants {
		fmt.Fprintf(w, arm, variants[i].msgType, variants[i].variantType)
	}
	fmt.Fprint(w, tail)
}

// main regenerates batch_generated.go in the current working directory from
// the oneof variants of roachpb.ErrorDetail, RequestUnion and ResponseUnion.
// It exits with status 1 if the file cannot be created, written or closed.
// It panics (via initVariants or responseVariantFor) if request and response
// variants cannot be paired.
func main() {
	initVariants()

	f, err := os.Create("batch_generated.go")
	if err != nil {
		fmt.Fprintln(os.Stderr, "Error opening file:", err)
		os.Exit(1)
	}

	// Writes to an *os.File are unbuffered, and the individual Fprint calls
	// below do not check their errors. A bufio.Writer remembers the first
	// write error and returns it from Flush, so one check covers every write.
	w := bufio.NewWriter(f)

	genHeader(w)

	// Generate GetInner methods.
	genGetInner(w, "ErrorDetail", "error", errVariants)
	genGetInner(w, "RequestUnion", "Request", reqVariants)
	genGetInner(w, "ResponseUnion", "Response", resVariants)

	// Generate SetInner methods.
	genSetInner(w, "ErrorDetail", "error", errVariants)
	genSetInner(w, "RequestUnion", "Request", reqVariants)
	genSetInner(w, "ResponseUnion", "Response", resVariants)

	genReqCounts(w)
	genSummary(w)
	genCreateReply(w)

	if err := w.Flush(); err != nil {
		fmt.Fprintln(os.Stderr, "Error writing file:", err)
		os.Exit(1)
	}
	if err := f.Close(); err != nil {
		fmt.Fprintln(os.Stderr, "Error closing file:", err)
		os.Exit(1)
	}
}

// genHeader writes the generated-file markers, the package clause and the
// imports used by the generated code.
func genHeader(w io.Writer) {
	// First comment for github/Go; second for reviewable.
	// https://github.com/golang/go/issues/13560#issuecomment-277804473
	// https://github.com/Reviewable/Reviewable/wiki/FAQ#how-do-i-tell-reviewable-that-a-file-is-generated-and-should-not-be-reviewed
	fmt.Fprint(w, `// Code generated by gen_batch.go; DO NOT EDIT.
// GENERATED FILE DO NOT EDIT

package roachpb

import (
	"fmt"
	"strconv"
	"strings"
)
`)
}

// genReqCounts writes the reqCounts array type, which has one slot per
// RequestUnion variant in reqVariants order. It also writes the
// BatchRequest.getReqCounts method that fills it.
func genReqCounts(w io.Writer) {
	fmt.Fprintf(w, `
type reqCounts [%d]int32
`, len(reqVariants))

	fmt.Fprint(w, `
// getReqCounts returns the number of times each
// request type appears in the batch.
func (ba *BatchRequest) getReqCounts() reqCounts {
	var counts reqCounts
	for _, ru := range ba.Requests {
		switch ru.GetValue().(type) {
`)

	for i, v := range reqVariants {
		fmt.Fprintf(w, `		case *%s:
			counts[%d]++
`, v.variantType, i)
	}

	// Fprintf with "%s" so go vet does not mistake the generated %+v for a
	// formatting directive of this call.
	fmt.Fprintf(w, "%s", `		default:
			panic(fmt.Sprintf("unsupported request: %+v", ru))
		}
	}
	return counts
}
`)
}

// genSummary writes requestNames, the abbreviated name of each RequestUnion
// variant in reqVariants order. It also writes the BatchRequest Summary and
// WriteSummary methods that use it.
func genSummary(w io.Writer) {
	// A few shorthands to help make the names more terse. They are kept in a
	// slice rather than a map so replacements are applied in a fixed order
	// and the generated file is identical from run to run.
	shorthands := []struct{ long, short string }{
		{"Delete", "Del"},
		{"Range", "Rng"},
		{"Transaction", "Txn"},
		{"Reverse", "Rev"},
		{"Admin", "Adm"},
		{"Increment", "Inc"},
		{"Conditional", "C"},
		{"Check", "Chk"},
		{"Truncate", "Trunc"},
	}

	fmt.Fprint(w, `
var requestNames = []string{`)
	for _, v := range reqVariants {
		name := v.variantName
		for _, s := range shorthands {
			name = strings.Replace(name, s.long, s.short, -1)
		}
		fmt.Fprintf(w, `
	"%s",`, name)
	}
	fmt.Fprint(w, `
}
`)

	fmt.Fprint(w, `
// Summary prints a short summary of the requests in a batch.
func (ba *BatchRequest) Summary() string {
	var b strings.Builder
	ba.WriteSummary(&b)
	return b.String()
}

// WriteSummary writes a short summary of the requests in a batch
// to the provided builder.
func (ba *BatchRequest) WriteSummary(b *strings.Builder) {
	if len(ba.Requests) == 0 {
		b.WriteString("empty batch")
		return
	}
	counts := ba.getReqCounts()
	var tmp [10]byte
	var comma bool
	for i, v := range counts {
		if v != 0 {
			if comma {
				b.WriteString(", ")
			}
			comma = true

			b.Write(strconv.AppendInt(tmp[:0], int64(v), 10))
			b.WriteString(" ")
			b.WriteString(requestNames[i])
		}
	}
}
`)
}

// responseVariantFor returns the ResponseUnion variant that initVariants
// paired with the RequestUnion variant reqV. It panics if there is none.
func responseVariantFor(reqV variantInfo) variantInfo {
	resV, ok := reqResVariantMapping[reqV]
	if !ok {
		panic(fmt.Sprintf("unknown response variant for %v", reqV))
	}
	return resV
}

// genCreateReply writes one <msgType>Alloc struct per ResponseUnion variant.
// Each struct pairs the union wrapper with its response so both come from a
// single allocation. It also writes the BatchRequest.CreateReply method that
// hands these structs out.
func genCreateReply(w io.Writer) {
	fmt.Fprint(w, `
// The following types are used to group the allocations of Responses
// and their corresponding isResponseUnion_Value union wrappers together.
`)
	// allocTypes maps a response variantName to its generated alloc type.
	allocTypes := make(map[string]string, len(resVariants))
	for _, resV := range resVariants {
		allocName := strings.ToLower(resV.msgType[:1]) + resV.msgType[1:] + "Alloc"
		fmt.Fprintf(w, `type %s struct {
	union %s
	resp  %s
}
`, allocName, resV.variantType, resV.msgType)
		allocTypes[resV.variantName] = allocName
	}

	fmt.Fprint(w, `
// CreateReply creates replies for each of the contained requests, wrapped in a
// BatchResponse. The response objects are batch allocated to minimize
// allocation overhead.
func (ba *BatchRequest) CreateReply() *BatchResponse {
	br := &BatchResponse{}
	br.Responses = make([]ResponseUnion, len(ba.Requests))

	counts := ba.getReqCounts()

`)

	// One buffer per request variant. buf<i> is indexed like counts, i.e. by
	// the variant's position in reqVariants.
	for i, v := range reqVariants {
		resV := responseVariantFor(v)
		fmt.Fprintf(w, "\tvar buf%d []%s\n", i, allocTypes[resV.variantName])
	}

	fmt.Fprint(w, `
	for i, r := range ba.Requests {
		switch r.GetValue().(type) {
`)

	for i, v := range reqVariants {
		resV := responseVariantFor(v)
		fmt.Fprintf(w, `		case *%[2]s:
			if buf%[1]d == nil {
				buf%[1]d = make([]%[3]s, counts[%[1]d])
			}
			buf%[1]d[0].union.%[4]s = &buf%[1]d[0].resp
			br.Responses[i].Value = &buf%[1]d[0].union
			buf%[1]d = buf%[1]d[1:]
`, i, v.variantType, allocTypes[resV.variantName], resV.variantName)
	}

	// Fprintf with "%s" so go vet does not mistake the generated %+v for a
	// formatting directive of this call.
	fmt.Fprintf(w, "%s", `		default:
			panic(fmt.Sprintf("unsupported request: %+v", r))
		}
	}
	return br
}
`)
}
