package assert

import (
	"errors"
	"fmt"
	"go/ast"

	"gotest.tools/v3/assert/cmp"
	"gotest.tools/v3/internal/format"
	"gotest.tools/v3/internal/source"
)

// RunComparison runs the comparison f and reports whether it succeeded. When
// the comparison fails, a message is printed using t.Log and false is
// returned.
//
// If t also implements helperT, its Helper method is called first.
//
// When source.IsUpdate reports true and the failed result implements
// updateExpected, the result's UpdatedExpected method is called; a nil error
// from it makes RunComparison return true. A source.ErrNotFound error is
// ignored and the regular failure message is logged. Any other error is logged
// and false is returned.
//
// argSelector picks which of the caller's argument expressions are handed to
// results that build their failure message from them. msgAndArgs are passed
// to format.WithCustomMessage together with the failure message.
func RunComparison(
	t LogT,
	argSelector argSelector,
	f cmp.Comparison,
	msgAndArgs ...interface{},
) bool {
	// Number of frames to skip: Assert/Check, assert, RunComparison.
	// The lookups below must be made directly from this function so that this
	// index stays correct.
	const callerStackIndex = 3

	if helper, isHelper := t.(helperT); isHelper {
		helper.Helper()
	}

	result := f()
	if result.Success() {
		return true
	}

	if source.IsUpdate() {
		if updater, canUpdate := result.(updateExpected); canUpdate {
			err := updater.UpdatedExpected(callerStackIndex)
			if err == nil {
				return true
			}
			if !errors.Is(err, source.ErrNotFound) {
				t.Log("failed to update source", err)
				return false
			}
			// source.ErrNotFound: nothing was updated, so continue on to the
			// regular failure message below.
		}
	}

	var message string
	if withArgs, ok := result.(resultWithComparisonArgs); ok {
		callArgs, err := source.CallExprArgs(callerStackIndex)
		if err != nil {
			// Best effort: log the lookup error and still build the message
			// from whatever args were returned.
			t.Log(err.Error())
		}
		selected := argSelector(callArgs)
		message = withArgs.FailureMessage(filterPrintableExpr(selected))
	} else if basic, ok := result.(resultBasic); ok {
		message = basic.FailureMessage()
	} else {
		message = fmt.Sprintf("comparison returned invalid Result type: %T", result)
	}

	t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
	return false
}

// resultWithComparisonArgs is matched by RunComparison against a failed
// comparison result. Results that implement it build their failure message
// from the argument expressions of the calling assertion; RunComparison passes
// the expressions after applying the argSelector and filterPrintableExpr, so
// entries of args may be nil.
type resultWithComparisonArgs interface {
	FailureMessage(args []ast.Expr) string
}

// resultBasic is matched by RunComparison against a failed comparison result
// that does not implement resultWithComparisonArgs. Its FailureMessage takes no
// arguments and returns the text to log.
type resultBasic interface {
	FailureMessage() string
}

// updateExpected is matched by RunComparison against a failed comparison
// result when source.IsUpdate reports true. UpdatedExpected receives the number
// of stack frames between the assertion call in the test and RunComparison's
// caller chain; a nil error is treated by RunComparison as success.
// NOTE(review): presumably the implementation rewrites the expected value in
// the test source — confirm against the implementers.
type updateExpected interface {
	UpdatedExpected(stackIndex int) error
}

// filterPrintableExpr returns a slice of the same length as args in which only
// the expressions that are easy to read when printed, and that add relevant
// information to an assertion, are kept. Every other position is left nil, so
// indexes in the result line up with indexes in args.
//
// An expression is kept as-is when isShortPrintableExpr accepts it (for
// example Ident and SelectorExpr, which print nicely and whose variable names
// may provide additional context to their values).
// A StarExpr is replaced by its operand, dropping the leading dereference.
// BasicLit and CompositeLit are left out because their source is equivalent to
// their value, which is already available.
// Other types are ignored for now, but could be added if they are relevant.
func filterPrintableExpr(args []ast.Expr) []ast.Expr {
	filtered := make([]ast.Expr, len(args))
	for idx, expr := range args {
		switch star, isStar := expr.(*ast.StarExpr); {
		case isShortPrintableExpr(expr):
			filtered[idx] = expr
		case isStar:
			filtered[idx] = star.X
		}
	}
	return filtered
}

// isShortPrintableExpr reports whether expr is one of the expression kinds
// that are kept verbatim in failure messages: identifiers, selectors, index
// and slice expressions, and binary and unary expressions.
//
// Everything else (including CallExpr, ParenExpr, TypeAssertExpr,
// KeyValueExpr, StarExpr and a nil expr) yields false.
func isShortPrintableExpr(expr ast.Expr) bool {
	printable := false
	switch expr.(type) {
	case *ast.Ident,
		*ast.SelectorExpr,
		*ast.IndexExpr,
		*ast.SliceExpr,
		*ast.BinaryExpr,
		*ast.UnaryExpr:
		printable = true
	}
	return printable
}

// argSelector is a function that, given the argument expressions of a call,
// returns the expressions that should be made available to a failure message.
// RunComparison applies it to the args returned by source.CallExprArgs.
// ArgsAfterT, ArgsFromComparisonCall and ArgsAtZeroIndex have this signature.
type argSelector func([]ast.Expr) []ast.Expr

// ArgsAfterT selects every arg after the first one. Used when the caller has a
// testing.T as the first argument, and the args to select should follow it.
//
// It returns nil when args is empty, and an empty slice when args holds only
// the first argument.
func ArgsAfterT(args []ast.Expr) []ast.Expr {
	if len(args) == 0 {
		return nil
	}
	// Drop the leading argument and keep the rest.
	rest := args[1:]
	return rest
}

// ArgsFromComparisonCall selects the args of the call expression found at
// position 1 of args. Used when the caller has a testing.T as the first
// argument, and the args to select are the ones passed to the cmp.Comparison
// constructor call at position 1.
//
// It returns nil when args has fewer than two entries or when the entry at
// position 1 is not a call expression.
func ArgsFromComparisonCall(args []ast.Expr) []ast.Expr {
	if len(args) < 2 {
		return nil
	}
	call, isCall := args[1].(*ast.CallExpr)
	if !isCall {
		return nil
	}
	return call.Args
}

// ArgsAtZeroIndex selects the args of the call expression found at position 0
// of args. Used when the caller accepts a single cmp.Comparison argument.
//
// It returns nil when args is empty or when the entry at position 0 is not a
// call expression.
func ArgsAtZeroIndex(args []ast.Expr) []ast.Expr {
	if len(args) < 1 {
		return nil
	}
	call, isCall := args[0].(*ast.CallExpr)
	if !isCall {
		return nil
	}
	return call.Args
}