package com.linkedin.metadata.graph;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.linkedin.common.Siblings;
import com.linkedin.common.UrnArray;
import com.linkedin.common.urn.Urn;
import com.linkedin.data.template.RecordTemplate;
import com.linkedin.metadata.entity.EntityService;
import com.linkedin.metadata.shared.ValidationUtils;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import static com.linkedin.metadata.Constants.*;


@Slf4j
@RequiredArgsConstructor
public class SiblingGraphService {

  private final EntityService _entityService;
  private final GraphService _graphService;

  /**
   * Fetches a single hop of lineage for {@code entityUrn}, merging in the lineage of its siblings.
   *
   * <p>Convenience overload of
   * {@link #getLineage(Urn, LineageDirection, int, int, int, boolean, Set, Long, Long)} that merges siblings,
   * starts from an empty visited set and passes no time bounds.
   *
   * @param entityUrn the entity whose lineage is requested
   * @param direction the lineage direction to traverse
   * @param offset the number of relationships to skip
   * @param count the maximum number of relationships to fetch
   * @param maxHops the maximum number of hops; values above 1 are rejected
   * @return the validated lineage result
   * @throws UnsupportedOperationException if {@code maxHops > 1}
   */
  @Nonnull
  public EntityLineageResult getLineage(@Nonnull Urn entityUrn, @Nonnull LineageDirection direction, int offset,
      int count, int maxHops) {
    return ValidationUtils.validateEntityLineageResult(getLineage(
            entityUrn,
            direction,
            offset,
            count,
            maxHops,
            false,
            new HashSet<>(),
            null,
            null),
        _entityService);
  }

  /**
   * Fetches lineage for {@code entityUrn} in the given direction.
   *
   * <p>When {@code separateSiblings} is true this delegates directly to the underlying {@link GraphService}.
   * Otherwise only a single hop is supported. The lineage of every sibling listed in the entity's
   * {@code Siblings} aspect is fetched (recursively, so siblings of siblings are included) and merged into one
   * result. During the merge:
   * <ul>
   *   <li>members of the sibling group are removed from each other's lineage;</li>
   *   <li>entities already present in the merged result are dropped;</li>
   *   <li>a non-primary entity is dropped when one of its own siblings is also in the merged result.</li>
   * </ul>
   *
   * <p>{@code offset} and {@code count} page over the siblings' lineage as if it were concatenated. Each
   * sibling is queried with the offset left after skipping the totals of the entities already fetched, and
   * with the number of relationships the merged page still lacks.
   *
   * @param entityUrn the entity whose lineage is requested
   * @param direction the lineage direction to traverse
   * @param offset the number of relationships to skip
   * @param count the maximum number of relationships to fetch
   * @param maxHops the maximum number of hops
   * @param separateSiblings if true, sibling lineage is not merged
   * @param visitedUrns entities already merged in this traversal; {@code entityUrn} is added to it when it has
   *     siblings so that recursive calls for those siblings do not come back to it
   * @param startTimeMillis passed through to {@link GraphService}; presumably the start of the time window
   * @param endTimeMillis passed through to {@link GraphService}; presumably the end of the time window
   * @return the validated lineage result
   * @throws UnsupportedOperationException if siblings are merged and {@code maxHops > 1}
   */
  @Nonnull
  public EntityLineageResult getLineage(@Nonnull Urn entityUrn, @Nonnull LineageDirection direction,
      int offset, int count, int maxHops, boolean separateSiblings, @Nonnull Set<Urn> visitedUrns,
      @Nullable Long startTimeMillis, @Nullable Long endTimeMillis) {
    if (separateSiblings) {
      return ValidationUtils.validateEntityLineageResult(_graphService.getLineage(
          entityUrn,
          direction,
          offset,
          count,
          maxHops,
          startTimeMillis,
          endTimeMillis), _entityService);
    }

    if (maxHops > 1) {
      throw new UnsupportedOperationException(
          String.format("More than 1 hop is not supported for %s", this.getClass().getSimpleName()));
    }

    EntityLineageResult entityLineage =
        _graphService.getLineage(
            entityUrn,
            direction,
            offset,
            count,
            maxHops,
            startTimeMillis,
            endTimeMillis);

    Siblings siblingAspectOfEntity = (Siblings) _entityService.getLatestAspect(entityUrn, SIBLINGS_ASPECT_NAME);

    // if you have siblings, we want to fetch their lineage too and merge it in
    if (siblingAspectOfEntity != null && siblingAspectOfEntity.hasSiblings()) {
      UrnArray siblingUrns = siblingAspectOfEntity.getSiblings();
      // explicit HashSet: the set is mutated below, and Collectors.toSet() does not guarantee mutability
      Set<Urn> allSiblingsInGroup = new HashSet<>(siblingUrns);
      allSiblingsInGroup.add(entityUrn);

      // remove your siblings from your lineage
      entityLineage =
          filterLineageResultFromSiblings(entityUrn, allSiblingsInGroup, entityLineage, null);

      // Page through the siblings as if their lineage followed this entity's: skip this entity's total from
      // the offset, and only ask for as many relationships as the merged page still lacks.
      int remainingOffset = Math.max(0, offset - entityLineage.getTotal());
      int remainingCount = Math.max(0, count - entityLineage.getRelationships().size());

      visitedUrns.add(entityUrn);
      // iterate through each sibling and include their lineage in the bunch
      for (Urn siblingUrn : siblingUrns) {
        if (visitedUrns.contains(siblingUrn)) {
          continue;
        }
        // recurse (rather than query the graph directly) in case this sibling has further siblings
        final EntityLineageResult siblingLineage = getLineage(
            siblingUrn,
            direction,
            remainingOffset,
            remainingCount,
            maxHops,
            false,
            visitedUrns,
            startTimeMillis,
            endTimeMillis);
        // read before merging: filterLineageResultFromSiblings adds the existing total into this result
        final int siblingTotal = siblingLineage.getTotal();

        final EntityLineageResult nextEntityLineage =
            filterLineageResultFromSiblings(siblingUrn, allSiblingsInGroup, siblingLineage, entityLineage);

        // The merged result already contains the previous entities' totals and relationships, so only the
        // sibling's own total is consumed from the offset, and the remaining count is measured against the
        // size of the merged page.
        remainingOffset = Math.max(0, remainingOffset - siblingTotal);
        remainingCount = Math.max(0, count - nextEntityLineage.getRelationships().size());

        entityLineage = nextEntityLineage;
      }
    }

    return ValidationUtils.validateEntityLineageResult(entityLineage, _entityService);
  }

  /**
   * Merges {@code entityLineageResult} into {@code existingResult} (if any) and de-duplicates the union.
   *
   * <ol>
   *   <li>Drops members of {@code allSiblingsInGroup} other than {@code urn} from {@code entityLineageResult}.</li>
   *   <li>Drops relationships whose entity already appears in {@code existingResult}.</li>
   *   <li>Places the remaining new relationships before the existing ones.</li>
   *   <li>Fetches the {@code Siblings} aspect of every entity in that union.</li>
   *   <li>Drops any entity whose aspect is non-primary and lists a sibling that is also in the union.</li>
   * </ol>
   *
   * <p>{@code entityLineageResult} is modified in place. Its relationships are replaced, its total becomes the
   * sum of both inputs' totals, and its count becomes the merged size. It is then validated and returned.
   *
   * @param urn the entity that {@code entityLineageResult} describes
   * @param allSiblingsInGroup the whole sibling group, including {@code urn}
   * @param entityLineageResult the lineage to merge in; mutated by this method
   * @param existingResult the lineage merged so far, or {@code null} if there is none
   * @return the validated merged result
   */
  private EntityLineageResult filterLineageResultFromSiblings(
      @Nonnull final Urn urn,
      @Nonnull final Set<Urn> allSiblingsInGroup,
      @Nonnull final EntityLineageResult entityLineageResult,
      @Nullable final EntityLineageResult existingResult
  ) {
    // 1) remove the source entities siblings from this entity's downstreams
    final List<LineageRelationship> filteredRelationships = entityLineageResult.getRelationships()
        .stream()
        .filter(lineageRelationship -> !allSiblingsInGroup.contains(lineageRelationship.getEntity())
            || lineageRelationship.getEntity().equals(urn))
        .collect(Collectors.toList());

    // 2) filter out existing lineage to avoid duplicates in our combined result
    final List<LineageRelationship> existingRelationships = existingResult != null
        ? existingResult.getRelationships()
        : ImmutableList.<LineageRelationship>of();
    final Set<Urn> existingUrns = existingRelationships.stream()
        .map(LineageRelationship::getEntity)
        .collect(Collectors.toSet());
    final List<LineageRelationship> uniqueFilteredRelationships = filteredRelationships.stream()
        .filter(lineageRelationship -> !existingUrns.contains(lineageRelationship.getEntity()))
        .collect(Collectors.toList());

    // 3) combine this entity's lineage with the lineage we've already seen
    final List<LineageRelationship> combinedResults =
        Stream.concat(uniqueFilteredRelationships.stream(), existingRelationships.stream())
            .collect(Collectors.toList());

    // 4) fetch the siblings of each lineage result
    final Set<Urn> combinedResultUrns =
        combinedResults.stream().map(LineageRelationship::getEntity).collect(Collectors.toSet());

    final Map<Urn, List<RecordTemplate>> siblingAspects =
        _entityService.getLatestAspects(combinedResultUrns, ImmutableSet.of(SIBLINGS_ASPECT_NAME));

    // 5) if you are not primary & your sibling is in the results, filter yourself out of the return set
    final List<LineageRelationship> mergedRelationships = combinedResults.stream()
        .filter(result -> !isNonPrimarySiblingOfAnyResult(siblingAspects.get(result.getEntity()), combinedResultUrns))
        .collect(Collectors.toList());

    entityLineageResult.setRelationships(new LineageRelationshipArray(mergedRelationships));
    entityLineageResult.setTotal(entityLineageResult.getTotal() + (existingResult != null ? existingResult.getTotal() : 0));
    entityLineageResult.setCount(mergedRelationships.size());
    return ValidationUtils.validateEntityLineageResult(entityLineageResult, _entityService);
  }

  /**
   * Returns true when an entity should be hidden from a merged result. That is the case when its aspects
   * contain a {@code Siblings} aspect that is not primary and lists at least one urn in {@code resultUrns}.
   *
   * @param aspects the entity's latest aspects, or {@code null} if none were returned for it
   * @param resultUrns every entity urn currently in the merged result
   * @return whether the entity is a non-primary sibling of another entity in the result
   */
  private static boolean isNonPrimarySiblingOfAnyResult(
      @Nullable final List<RecordTemplate> aspects,
      @Nonnull final Set<Urn> resultUrns
  ) {
    if (aspects == null) {
      // no aspects returned for this urn: treat it like an entity without a Siblings aspect and keep it
      return false;
    }
    final Optional<Siblings> siblingsAspect = aspects.stream()
        .filter(Siblings.class::isInstance)
        .map(Siblings.class::cast)
        .findAny();
    return siblingsAspect
        .map(siblings -> !siblings.isPrimary() && siblings.getSiblings().stream().anyMatch(resultUrns::contains))
        .orElse(false);
  }

}
