package traceql

import (
	"errors"
	"fmt"
	"math"
	"regexp"
	"strings"
)

// evaluate partitions the spans of each input spanset by the value of
// g.Expression. Every distinct value becomes a new spanset that carries the
// source spanset's attributes plus an attribute named g.String() holding that
// value. Groups are emitted in the order their first span was seen, so the
// output order is deterministic for a given input.
//
// NOTE(review): g.groupBuffer is reused as scratch space and cleared for every
// spanset; this assumes it is a non-nil map (assigning into a nil map panics)
// and that no other evaluate call uses the same buffer concurrently.
func (g GroupOperation) evaluate(ss []*Spanset) ([]*Spanset, error) {
	result := make([]*Spanset, 0, len(ss))
	groups := g.groupBuffer

	// ordered holds the map's values in insertion order; ranging over the map
	// instead would emit groups in a random order.
	var ordered []*Spanset

	for _, spanset := range ss {
		// Reset the scratch state left over from the previous spanset.
		for k := range groups {
			delete(groups, k)
		}
		ordered = ordered[:0]

		for _, span := range spanset.Spans {
			key, err := g.Expression.execute(span)
			if err != nil {
				return nil, err
			}

			group, ok := groups[key]
			if !ok {
				group = &Spanset{}
				// Copy the source spanset's attributes forward before tagging
				// the group with its own value.
				group.Attributes = append(group.Attributes, spanset.Attributes...)
				group.AddAttribute(g.String(), key)
				groups[key] = group
				ordered = append(ordered, group)
			}

			group.Spans = append(group.Spans, span)
		}

		// ordered is reused, but appending copies the pointers out of it.
		result = append(result, ordered...)
	}

	return result, nil
}

// CoalesceOperation undoes grouping: it merges spansets back into a single
// spanset per trace. The engine only ever hands this operation spansets from one
// trace, so every spanset in ss is concatenated into one result without checking
// trace IDs. Only Spans are carried over; the merged spanset is otherwise empty.
func (CoalesceOperation) evaluate(ss []*Spanset) ([]*Spanset, error) {
	total := 0
	for i := range ss {
		total += len(ss[i].Spans)
	}

	merged := make([]Span, 0, total)
	for i := range ss {
		merged = append(merged, ss[i].Spans...)
	}

	return []*Spanset{{Spans: merged}}, nil
}

// evaluate applies a spanset operator to each input spanset independently.
// The LHS and RHS pipelines are each evaluated against a one-element slice
// holding that spanset. With OpSpansetAnd the spanset is kept when both sides
// return at least one spanset; with OpSpansetUnion it is kept when either side
// does. A kept spanset is a clone of the input whose Spans are the spans
// returned by both sides, combined by uniqueSpans.
//
// Any other operator is an error (reported on the first input spanset).
func (o SpansetOperation) evaluate(input []*Spanset) ([]*Spanset, error) {
	var output []*Spanset

	for i := range input {
		// Cap the window at one element so an append by either side cannot
		// overwrite input[i+1:], which later iterations still read.
		curr := input[i : i+1 : i+1]

		lhs, err := o.LHS.evaluate(curr)
		if err != nil {
			return nil, err
		}

		rhs, err := o.RHS.evaluate(curr)
		if err != nil {
			return nil, err
		}

		var matched bool
		switch o.Op {
		case OpSpansetAnd:
			matched = len(lhs) > 0 && len(rhs) > 0
		case OpSpansetUnion:
			matched = len(lhs) > 0 || len(rhs) > 0
		default:
			return nil, fmt.Errorf("spanset operation (%v) not supported", o.Op)
		}

		if matched {
			matchingSpanset := input[i].clone()
			matchingSpanset.Spans = uniqueSpans(lhs, rhs)
			output = append(output, matchingSpanset)
		}
	}

	return output, nil
}

// evaluate for SelectOperation passes its input through untouched: the fetch
// layer has already decorated the spans with the requested attributes, so
// there is nothing left to do at this stage.
func (SelectOperation) evaluate(input []*Spanset) ([]*Spanset, error) {
	passthrough := input
	return passthrough, nil
}

// evaluate filters input by comparing an aggregate against a static value.
//
// Only the form "aggregate <op> static" is supported. The Aggregate in f.lhs
// is evaluated over input, and each resulting spanset is kept when
// binOp(f.op, spanset.Scalar, f.rhs) is true. Any other operand combination
// returns an error naming the unsupported operand. Errors from binOp are
// wrapped, so callers can inspect them with errors.Is/As.
func (f ScalarFilter) evaluate(input []*Spanset) (output []*Spanset, err error) {
	// TODO: close the gap where pipeline elements and scalar binary operations
	// meet in a generic way. For now only the well-defined case
	// "aggregate binop static" is supported.
	agg, ok := f.lhs.(Aggregate)
	if !ok {
		return nil, fmt.Errorf("scalar filter lhs (%v) not supported", f.lhs)
	}
	rhs, ok := f.rhs.(Static)
	if !ok {
		// The lhs is fine here; it is the rhs that is unsupported.
		return nil, fmt.Errorf("scalar filter rhs (%v) not supported", f.rhs)
	}

	aggregated, err := agg.evaluate(input)
	if err != nil {
		return nil, err
	}

	for _, ss := range aggregated {
		keep, opErr := binOp(f.op, ss.Scalar, rhs)
		if opErr != nil {
			return nil, fmt.Errorf("scalar filter (%v) failed: %w", f, opErr)
		}
		if keep {
			output = append(output, ss)
		}
	}

	return output, nil
}

// evaluate computes a.op over the spans of each input spanset. For every
// input spanset it emits a clone whose Scalar holds the aggregate value. The
// value is also added as an attribute named a.String().
//
// count() of an empty spanset is 0. avg, max, min and sum have no value for an
// empty spanset and yield NewStaticNil() rather than dereferencing a nil
// accumulator. Errors from evaluating a.e on a span are returned as-is; an
// unsupported a.op is an error.
func (a Aggregate) evaluate(input []*Spanset) ([]*Spanset, error) {
	var output []*Spanset

	for _, ss := range input {
		// Default for the value-based aggregates when ss has no spans.
		scalar := NewStaticNil()

		switch a.op {
		case aggregateCount:
			scalar = NewStaticInt(len(ss.Spans))

		case aggregateAvg, aggregateSum:
			var sum *Static
			count := 0
			for _, s := range ss.Spans {
				val, err := a.e.execute(s)
				if err != nil {
					return nil, err
				}
				if sum == nil {
					sum = &val
				} else {
					sum.sumInto(val)
				}
				count++
			}
			if sum != nil {
				if a.op == aggregateAvg {
					scalar = sum.divideBy(float64(count))
				} else {
					scalar = *sum
				}
			}

		case aggregateMax, aggregateMin:
			// Replace the current extreme when compare reports 1 (max) or
			// -1 (min) for the new value against it.
			want := 1
			if a.op == aggregateMin {
				want = -1
			}
			var extreme *Static
			for _, s := range ss.Spans {
				val, err := a.e.execute(s)
				if err != nil {
					return nil, err
				}
				if extreme == nil || val.compare(extreme) == want {
					extreme = &val
				}
			}
			if extreme != nil {
				scalar = *extreme
			}

		default:
			return nil, fmt.Errorf("aggregate operation (%v) not supported", a.op)
		}

		cloned := ss.clone()
		cloned.Scalar = scalar
		cloned.AddAttribute(a.String(), cloned.Scalar)
		output = append(output, cloned)
	}

	return output, nil
}

// execute evaluates o.LHS and o.RHS against span and applies o.Op to the
// results.
//
// If the operands' implied types do not match, or o.Op is not valid for them,
// the result is false rather than an error. Ordering operators on two strings
// compare the raw string values lexically. All other ordering and arithmetic
// operators work on asFloat(). OpRegex and OpNotRegex use rhs as the pattern
// and lhs as the subject; an invalid pattern is returned as an error. Any
// other operator is an error.
func (o BinaryOperation) execute(span Span) (Static, error) {
	lhs, err := o.LHS.execute(span)
	if err != nil {
		return NewStaticNil(), err
	}

	rhs, err := o.RHS.execute(span)
	if err != nil {
		return NewStaticNil(), err
	}

	// Ensure the resolved types are still valid.
	lhsT := lhs.impliedType()
	rhsT := rhs.impliedType()
	if !lhsT.isMatchingOperand(rhsT) || !o.Op.binaryTypesValid(lhsT, rhsT) {
		return NewStaticBool(false), nil
	}

	if lhsT == TypeString && rhsT == TypeString {
		// Order by the raw .S values, the same field the regex operators use.
		// NOTE(review): String() is a formatting method and may decorate the
		// value, so it is deliberately not used for comparison.
		switch o.Op {
		case OpGreater:
			return NewStaticBool(strings.Compare(lhs.S, rhs.S) > 0), nil
		case OpGreaterEqual:
			return NewStaticBool(strings.Compare(lhs.S, rhs.S) >= 0), nil
		case OpLess:
			return NewStaticBool(strings.Compare(lhs.S, rhs.S) < 0), nil
		case OpLessEqual:
			return NewStaticBool(strings.Compare(lhs.S, rhs.S) <= 0), nil
		}
	}

	switch o.Op {
	case OpAdd:
		return NewStaticFloat(lhs.asFloat() + rhs.asFloat()), nil
	case OpSub:
		return NewStaticFloat(lhs.asFloat() - rhs.asFloat()), nil
	case OpDiv:
		return NewStaticFloat(lhs.asFloat() / rhs.asFloat()), nil
	case OpMod:
		return NewStaticFloat(math.Mod(lhs.asFloat(), rhs.asFloat())), nil
	case OpMult:
		return NewStaticFloat(lhs.asFloat() * rhs.asFloat()), nil
	case OpGreater:
		return NewStaticBool(lhs.asFloat() > rhs.asFloat()), nil
	case OpGreaterEqual:
		return NewStaticBool(lhs.asFloat() >= rhs.asFloat()), nil
	case OpLess:
		return NewStaticBool(lhs.asFloat() < rhs.asFloat()), nil
	case OpLessEqual:
		return NewStaticBool(lhs.asFloat() <= rhs.asFloat()), nil
	case OpPower:
		return NewStaticFloat(math.Pow(lhs.asFloat(), rhs.asFloat())), nil
	case OpEqual:
		return NewStaticBool(lhs.Equals(rhs)), nil
	case OpNotEqual:
		return NewStaticBool(!lhs.Equals(rhs)), nil
	case OpRegex:
		// NOTE(review): MatchString compiles the pattern on every call, i.e.
		// once per span; caching the compiled regexp would avoid that.
		matched, err := regexp.MatchString(rhs.S, lhs.S)
		return NewStaticBool(matched), err
	case OpNotRegex:
		matched, err := regexp.MatchString(rhs.S, lhs.S)
		return NewStaticBool(!matched), err
	case OpAnd:
		return NewStaticBool(lhs.B && rhs.B), nil
	case OpOr:
		return NewStaticBool(lhs.B || rhs.B), nil
	default:
		return NewStaticNil(), errors.New("unexpected operator " + o.Op.String())
	}
}

// binOp applies a comparison or logical operator to two already-evaluated
// Static values and reports the boolean result. ScalarFilter uses it to compare
// an aggregate's scalar against a literal. BinaryOperation.execute, by
// contrast, evaluates span-bound expressions and also supports arithmetic and
// regex operators. binOp handles only OpGreater, OpGreaterEqual, OpLess,
// OpLessEqual, OpEqual, OpNotEqual, OpAnd and OpOr.
//
// Operands whose implied types do not match, or are not valid for op, yield
// (false, nil). Any other operator yields an error.
func binOp(op Operator, lhs, rhs Static) (bool, error) {
	lhsT := lhs.impliedType()
	rhsT := rhs.impliedType()
	if !lhsT.isMatchingOperand(rhsT) || !op.binaryTypesValid(lhsT, rhsT) {
		return false, nil
	}

	// Order two strings lexically by their raw values, as
	// BinaryOperation.execute orders strings lexically, instead of going
	// through asFloat.
	if lhsT == TypeString && rhsT == TypeString {
		switch op {
		case OpGreater:
			return strings.Compare(lhs.S, rhs.S) > 0, nil
		case OpGreaterEqual:
			return strings.Compare(lhs.S, rhs.S) >= 0, nil
		case OpLess:
			return strings.Compare(lhs.S, rhs.S) < 0, nil
		case OpLessEqual:
			return strings.Compare(lhs.S, rhs.S) <= 0, nil
		}
	}

	switch op {
	case OpGreater:
		return lhs.asFloat() > rhs.asFloat(), nil
	case OpGreaterEqual:
		return lhs.asFloat() >= rhs.asFloat(), nil
	case OpLess:
		return lhs.asFloat() < rhs.asFloat(), nil
	case OpLessEqual:
		return lhs.asFloat() <= rhs.asFloat(), nil
	case OpEqual:
		return lhs.Equals(rhs), nil
	case OpNotEqual:
		return !lhs.Equals(rhs), nil
	case OpAnd:
		return lhs.B && rhs.B, nil
	case OpOr:
		return lhs.B || rhs.B, nil
	}

	return false, errors.New("unexpected operator " + op.String())
}

// execute evaluates o.Expression against span and applies the unary operator.
//
// OpNot requires a boolean operand and returns its negation. OpSub requires a
// numeric operand and negates int, float and duration values. Operand type
// mismatches are reported as errors. So is any other operator, or a numeric
// type that the OpSub switch does not handle.
func (o UnaryOperation) execute(span Span) (Static, error) {
	operand, err := o.Expression.execute(span)
	if err != nil {
		return NewStaticNil(), err
	}

	switch o.Op {
	case OpNot:
		if operand.Type != TypeBoolean {
			return NewStaticNil(), fmt.Errorf("expression (%v) expected a boolean, but got %v", o, operand.Type)
		}
		return NewStaticBool(!operand.B), nil

	case OpSub:
		if !operand.Type.isNumeric() {
			return NewStaticNil(), fmt.Errorf("expression (%v) expected a numeric, but got %v", o, operand.Type)
		}
		switch operand.Type {
		case TypeInt:
			return NewStaticInt(-1 * operand.N), nil
		case TypeFloat:
			return NewStaticFloat(-1 * operand.F), nil
		case TypeDuration:
			return NewStaticDuration(-1 * operand.D), nil
		}
	}

	return NewStaticNil(), errors.New("UnaryOperation has Op different from Not and Sub")
}

// execute returns the literal itself: a Static evaluates to the same value for
// every span and never fails.
func (s Static) execute(_ Span) (Static, error) {
	value := s
	return value, nil
}

// execute resolves attribute a against span's attributes.
//
// An exact key match wins. For an unscoped attribute (AttributeScopeNone), the
// lookup then falls back to matching by name alone. A span-scoped attribute is
// preferred over any other scope. If nothing matches, a nil Static is
// returned; a missing attribute is never an error.
func (a Attribute) execute(span Span) (Static, error) {
	atts := span.Attributes()
	if exact, found := atts[a]; found {
		return exact, nil
	}
	if a.Scope != AttributeScopeNone {
		return NewStaticNil(), nil
	}

	var (
		fallback    Static
		hasFallback bool
	)
	for candidate, value := range atts {
		if candidate.Name != a.Name {
			continue
		}
		if candidate.Scope == AttributeScopeSpan {
			return value, nil
		}
		if !hasFallback {
			fallback, hasFallback = value, true
		}
	}

	if hasFallback {
		// NOTE(review): map iteration order is random, so when the name exists
		// under several non-span scopes the chosen value can vary between calls.
		return fallback, nil
	}
	return NewStaticNil(), nil
}

// uniqueSpans merges the spans of ss1 and ss2. All spans of the side holding
// fewer spans come first (ss2 when the counts are equal). They are followed by
// the spans of the other side that do not already appear in the first side,
// compared as map keys. Only cross-side duplicates are removed; repeats within
// a single side are kept.
func uniqueSpans(ss1 []*Spanset, ss2 []*Spanset) []Span {
	total := func(sets []*Spanset) (n int) {
		for _, set := range sets {
			n += len(set.Spans)
		}
		return n
	}
	n1, n2 := total(ss1), total(ss2)

	// The lookup map is built from the smaller side to keep it small.
	smaller, larger, smallerLen := ss2, ss1, n2
	if n1 < n2 {
		smaller, larger, smallerLen = ss1, ss2, n1
	}

	out := make([]Span, 0, n1+n2)
	seen := make(map[Span]struct{}, smallerLen)

	for _, set := range smaller {
		for _, span := range set.Spans {
			seen[span] = struct{}{}
		}
		out = append(out, set.Spans...)
	}

	for _, set := range larger {
		for _, span := range set.Spans {
			if _, dup := seen[span]; !dup {
				out = append(out, span)
			}
		}
	}

	return out
}
