package tocraft.walkers.skills.impl;

import com.mojang.serialization.Codec;
import com.mojang.serialization.codecs.RecordCodecBuilder;
import org.jetbrains.annotations.NotNull;
import tocraft.walkers.Walkers;
import tocraft.walkers.skills.ShapeSkill;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Stream;
import net.minecraft.class_1299;
import net.minecraft.class_1309;
import net.minecraft.class_2960;
import net.minecraft.class_6862;
import net.minecraft.class_7923;
import net.minecraft.class_7924;

/**
 * Shape skill that decides which entities a shape may ride.
 *
 * <p>An entity is rideable when it matches at least one criterion:
 * <ul>
 *   <li>its type is in the list of rideable entity types,</li>
 *   <li>it is an instance of one of the rideable classes,</li>
 *   <li>its type is in one of the rideable entity-type tags, or</li>
 *   <li>one of the arbitrary rideable predicates accepts it.</li>
 * </ul>
 *
 * <p>Only the type and tag criteria are data-driven. {@link #CODEC} reads and writes them under
 * the optional keys {@code "rideable"} and {@code "rideable_tags"}. Class and predicate criteria
 * exist only in code and are never encoded.
 *
 * @param <E> the shape entity type this skill is attached to
 */
@SuppressWarnings("unused")
public class RiderSkill<E extends class_1309> extends ShapeSkill<E> {
    public static final class_2960 ID = Walkers.id("rider");

    /**
     * Codec for the data-driven part of this skill (entity-type ids and entity-type tag ids).
     *
     * <p>On decode, a type id for which the entity-type registry check fails is skipped silently
     * instead of failing the whole skill. NOTE(review): presumably this lets data files mention
     * entities from mods that are not installed; confirm. Tag ids are always kept and wrapped
     * into tag keys without validation.
     */
    public static final Codec<RiderSkill<?>> CODEC = RecordCodecBuilder.create((instance) -> instance.group(
            Codec.list(class_2960.field_25139).optionalFieldOf("rideable", new ArrayList<>()).forGetter(o -> o.rideableTypes.stream().map(class_7923.field_41177::method_10221).toList()),
            Codec.list(class_2960.field_25139).optionalFieldOf("rideable_tags", new ArrayList<>()).forGetter(o -> o.rideableTags.stream().map(class_6862::comp_327).toList())
    ).apply(instance, instance.stable((rideableTypeIds, rideableTagIds) -> {
        List<class_1299<?>> rideableTypes = new ArrayList<>();
        List<class_6862<class_1299<?>>> rideableTags = new ArrayList<>();
        for (class_2960 rideableTypeId : rideableTypeIds) {
            // Unknown ids are dropped instead of producing a decode error.
            if (class_7923.field_41177.method_10250(rideableTypeId)) {
                rideableTypes.add(class_7923.field_41177.method_10223(rideableTypeId));
            }
        }
        for (class_2960 rideableTagId : rideableTagIds) {
            rideableTags.add(class_6862.method_40092(class_7924.field_41266, rideableTagId));
        }
        // Predicates and classes cannot be expressed in data, so they start empty.
        return new RiderSkill<>(new ArrayList<>(), rideableTypes, new ArrayList<>(), rideableTags);
    })));

    /** Arbitrary code-only criteria; never serialized. */
    private final List<Predicate<class_1309>> rideablePredicates;
    /** Exact entity types that may be ridden; serialized under {@code "rideable"}. */
    private final List<class_1299<?>> rideableTypes;
    /** Entity classes whose instances may be ridden; never serialized. */
    private final List<Class<? extends class_1309>> rideableClasses;
    /** Entity-type tags whose members may be ridden; serialized under {@code "rideable_tags"}. */
    private final List<class_6862<class_1299<?>>> rideableTags;

    /**
     * Creates a skill that allows riding entities of exactly the given types.
     *
     * <p>The types are stored as type criteria, not wrapped in predicates, so {@link #CODEC}
     * encodes them under {@code "rideable"} and they survive an encode/decode round-trip.
     *
     * @param rideable the entity types that may be ridden; must not contain {@code null}
     * @return a new skill matching those types
     */
    public static RiderSkill<?> ofRideableType(class_1299<?>... rideable) {
        return new RiderSkill<>(List.of(), List.of(rideable), List.of(), List.of());
    }

    /**
     * Creates a skill that allows riding any entity that is an instance of one of the given classes.
     *
     * @param rideable the entity classes that may be ridden; must not contain {@code null}
     * @return a new skill matching those classes (class criteria are not serialized by {@link #CODEC})
     */
    @SafeVarargs
    public static RiderSkill<?> ofRideableClass(Class<? extends class_1309>... rideable) {
        // List.of copies the elements; the varargs array itself is never retained or exposed.
        return new RiderSkill<>(List.of(), List.of(), List.of(rideable), List.of());
    }

    /**
     * Creates a skill driven only by arbitrary predicates.
     *
     * @param rideablePredicates predicates that each decide whether an entity may be ridden
     */
    public RiderSkill(@NotNull List<Predicate<class_1309>> rideablePredicates) {
        this(rideablePredicates, new ArrayList<>(), new ArrayList<>(), new ArrayList<>());
    }

    /**
     * Creates a skill from all four kinds of criteria.
     *
     * <p>Each list is copied, so later changes to the caller's lists do not affect this skill.
     *
     * @param rideablePredicates arbitrary criteria (not serialized)
     * @param rideableTypes      exact entity types (serialized under {@code "rideable"})
     * @param rideableClasses    entity classes (not serialized)
     * @param rideableTags       entity-type tags (serialized under {@code "rideable_tags"})
     * @throws NullPointerException if any list, or any element of any list, is {@code null}
     */
    public RiderSkill(@NotNull List<Predicate<class_1309>> rideablePredicates, @NotNull List<class_1299<?>> rideableTypes, @NotNull List<Class<? extends class_1309>> rideableClasses, @NotNull List<class_6862<class_1299<?>>> rideableTags) {
        this.rideablePredicates = List.copyOf(rideablePredicates);
        this.rideableTypes = List.copyOf(rideableTypes);
        this.rideableClasses = List.copyOf(rideableClasses);
        this.rideableTags = List.copyOf(rideableTags);
    }

    /**
     * Returns whether the given entity matches any criterion of this skill.
     *
     * <p>The checks run in this order: exact type, class, tag, then predicate. The first match
     * returns {@code true}.
     *
     * @param entity the entity to test
     * @return {@code true} if at least one criterion matches
     */
    public boolean isRideable(class_1309 entity) {
        class_1299<?> type = entity.method_5864();
        if (rideableTypes.contains(type)) return true;
        for (Class<? extends class_1309> rideableClass : rideableClasses) {
            if (rideableClass.isInstance(entity)) return true;
        }
        for (class_6862<class_1299<?>> rideableTag : rideableTags) {
            if (type.method_20210(rideableTag)) return true;
        }
        for (Predicate<class_1309> rideablePredicate : rideablePredicates) {
            if (rideablePredicate.test(entity)) return true;
        }
        return false;
    }

    /** Returns the identifier of this skill, {@link #ID}. */
    @Override
    public class_2960 getId() {
        return ID;
    }

    /** Returns {@link #CODEC}, which serializes only the type and tag criteria. */
    @Override
    public Codec<? extends ShapeSkill<?>> codec() {
        return CODEC;
    }
}
