﻿// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Microsoft.VisualStudio.Threading.Analyzers
{
    using System;
    using System.Collections.Generic;
    using System.Collections.Immutable;
    using System.Linq;
    using System.Text;
    using Microsoft.CodeAnalysis;
    using Microsoft.CodeAnalysis.CSharp;
    using Microsoft.CodeAnalysis.CSharp.Syntax;
    using Microsoft.CodeAnalysis.Diagnostics;

    internal static class CSharpCommonInterest
    {
        /// <summary>
        /// The syntax kinds of method-like declarations and expressions: constructors, methods, operators,
        /// anonymous methods and lambdas, and property/event accessors.
        /// </summary>
        internal static readonly IImmutableSet<SyntaxKind> MethodSyntaxKinds = ImmutableHashSet.Create(
            SyntaxKind.ConstructorDeclaration,
            SyntaxKind.MethodDeclaration,
            SyntaxKind.OperatorDeclaration,
            SyntaxKind.AnonymousMethodExpression,
            SyntaxKind.SimpleLambdaExpression,
            SyntaxKind.ParenthesizedLambdaExpression,
            SyntaxKind.GetAccessorDeclaration,
            SyntaxKind.SetAccessorDeclaration,
            SyntaxKind.AddAccessorDeclaration,
            SyntaxKind.RemoveAccessorDeclaration);

        /// <summary>
        /// This is an explicit rule to ignore the code that was generated by Xaml2CS.
        /// </summary>
        /// <param name="context">The analysis context whose node is being inspected.</param>
        /// <returns>
        /// <see langword="true"/> if the trivia attached to the <c>namespace</c> keyword of the nearest enclosing
        /// namespace declaration contains "&lt;auto-generated&gt;"; otherwise <see langword="false"/>.
        /// </returns>
        /// <remarks>
        /// The generated code has comments like this:
        /// <![CDATA[
        ///   //------------------------------------------------------------------------------
        ///   // <auto-generated>
        /// ]]>
        /// This rule is based on the fact that the keyword "&lt;auto-generated&gt;" should be found in the comments.
        /// Code with no enclosing <see cref="NamespaceDeclarationSyntax"/> is never ignored by this rule.
        /// </remarks>
        internal static bool ShouldIgnoreContext(SyntaxNodeAnalysisContext context)
        {
            NamespaceDeclarationSyntax? namespaceDeclaration = context.Node.FirstAncestorOrSelf<NamespaceDeclarationSyntax>();
            if (namespaceDeclaration is object)
            {
                foreach (SyntaxTrivia trivia in namespaceDeclaration.NamespaceKeyword.GetAllTrivia())
                {
                    const string autoGeneratedKeyword = @"<auto-generated>";

                    // The length check is a cheap filter that avoids allocating a string for short trivia (e.g. whitespace).
                    if (trivia.FullSpan.Length > autoGeneratedKeyword.Length
                        && trivia.ToString().Contains(autoGeneratedKeyword))
                    {
                        return true;
                    }
                }
            }

            return false;
        }

        /// <summary>
        /// Reports <paramref name="descriptor"/> at the member name of <paramref name="memberAccessSyntax"/> when that
        /// member matches one of <paramref name="problematicMethods"/> (same member name, same containing type name and
        /// namespace), unless the receiving task is known to have already completed.
        /// </summary>
        /// <param name="context">The analysis context used for semantic lookups and for reporting.</param>
        /// <param name="memberAccessSyntax">The member access to inspect; nothing happens when it is <see langword="null"/>.</param>
        /// <param name="descriptor">The diagnostic to report.</param>
        /// <param name="problematicMethods">The synchronously blocking members to look for.</param>
        /// <param name="ignoreIfInsideAnonymousDelegate">
        /// When <see langword="true"/>, nodes inside an anonymous function are not analyzed.
        /// </param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="descriptor"/> is <see langword="null"/>.</exception>
        internal static void InspectMemberAccess(
            SyntaxNodeAnalysisContext context,
            MemberAccessExpressionSyntax? memberAccessSyntax,
            DiagnosticDescriptor descriptor,
            IEnumerable<CommonInterest.SyncBlockingMethod> problematicMethods,
            bool ignoreIfInsideAnonymousDelegate = false)
        {
            if (descriptor is null)
            {
                throw new ArgumentNullException(nameof(descriptor));
            }

            if (memberAccessSyntax is null)
            {
                return;
            }

            if (ShouldIgnoreContext(context))
            {
                return;
            }

            if (ignoreIfInsideAnonymousDelegate && context.Node.FirstAncestorOrSelf<AnonymousFunctionExpressionSyntax>() is object)
            {
                // We do not analyze JTF.Run inside anonymous functions because
                // they are so often used as callbacks where the signature is constrained.
                return;
            }

            if (CSharpUtils.IsWithinNameOf(context.Node as ExpressionSyntax))
            {
                // We do not consider arguments to nameof( ) because they do not represent invocations of code.
                return;
            }

            ITypeSymbol? typeReceiver = context.SemanticModel.GetTypeInfo(memberAccessSyntax.Expression, context.CancellationToken).Type;
            if (typeReceiver is null)
            {
                return;
            }

            string memberName = memberAccessSyntax.Name.Identifier.Text;
            foreach (CommonInterest.SyncBlockingMethod item in problematicMethods)
            {
                if (memberName == item.Method.Name &&
                    typeReceiver.Name == item.Method.ContainingType.Name &&
                    typeReceiver.BelongsToNamespace(item.Method.ContainingType.Namespace))
                {
                    // HasTaskCompleted does not depend on `item`, so a completed task suppresses every match at once.
                    if (HasTaskCompleted(context, memberAccessSyntax))
                    {
                        return;
                    }

                    Location location = memberAccessSyntax.Name.GetLocation();
                    context.ReportDiagnostic(Diagnostic.Create(descriptor, location));
                }
            }
        }

        /// <summary>
        /// Returns <paramref name="node"/> itself or its nearest ancestor of kind <see cref="SyntaxKind.Block"/>,
        /// or <see langword="null"/> if there is none.
        /// </summary>
        private static SyntaxNode? GetEnclosingBlock(SyntaxNode? node)
        {
            for (SyntaxNode? current = node; current is not null; current = current.Parent)
            {
                if (current.IsKind(SyntaxKind.Block))
                {
                    return current;
                }
            }

            return null;
        }

        /// <summary>
        /// Determines whether a bare identifier named <paramref name="variableName"/> appears as an argument of
        /// <paramref name="invocationExpr"/>.
        /// </summary>
        /// <param name="invocationExpr">The invocation whose arguments are inspected.</param>
        /// <param name="variableName">The identifier text to look for.</param>
        /// <param name="byRef">When <see langword="true"/>, only <c>ref</c> and <c>out</c> arguments are considered.</param>
        private static bool IsVariablePassedToInvocation(InvocationExpressionSyntax invocationExpr, string variableName, bool byRef)
        {
            foreach (ArgumentSyntax arg in invocationExpr.ArgumentList.Arguments)
            {
                // `byRef` includes `out` parameters because they are the same as `ref` except don't require initialization first.
                if (byRef && !arg.RefKindKeyword.IsKind(SyntaxKind.RefKeyword) && !arg.RefKindKeyword.IsKind(SyntaxKind.OutKeyword))
                {
                    continue;
                }

                // Arguments that are not a bare identifier (calls, literals, member accesses, `out var` declarations, ...)
                // cannot be the variable itself; skip them and keep checking the remaining arguments.
                if (arg.Expression is IdentifierNameSyntax identifierName
                    && identifierName.Identifier.ValueText == variableName)
                {
                    return true;
                }
            }

            return false;
        }

        /// <summary>
        /// Determines whether <paramref name="invocationExpr"/> is an awaited <c>Task.WhenAll(...)</c> call on the
        /// <c>Task</c> type described by <c>Types.Task</c>, with <paramref name="taskVariableName"/> passed directly as an argument.
        /// </summary>
        private static bool IsTaskCompletedWithWhenAll(SyntaxNodeAnalysisContext context, InvocationExpressionSyntax invocationExpr, string taskVariableName)
        {
            // Only an awaited Task.WhenAll guarantees that its tasks have completed by the time execution continues past it.
            if (invocationExpr.Parent is not AwaitExpressionSyntax)
            {
                return false;
            }

            // Does the invocation have the expected `Task.WhenAll` syntax? This is cheaper to verify before looking up its semantic type.
            if (invocationExpr.Expression is not MemberAccessExpressionSyntax memberAccess
                || memberAccess.Expression is not IdentifierNameSyntax { Identifier.ValueText: Types.Task.TypeName }
                || memberAccess.Name is not IdentifierNameSyntax { Identifier.ValueText: Types.Task.WhenAll })
            {
                return false;
            }

            // Is this `Task.WhenAll` invocation from the System.Threading.Tasks.Task type?
            // The type may be null when `Task` does not bind to a type, so it is checked before being dereferenced.
            ITypeSymbol? classType = context.SemanticModel.GetTypeInfo(memberAccess.Expression, context.CancellationToken).Type;
            if (classType is null
                || classType.Name != Types.Task.TypeName
                || !classType.BelongsToNamespace(Types.Task.Namespace))
            {
                return false;
            }

            // Is the task variable passed as an argument to `Task.WhenAll`?
            return IsVariablePassedToInvocation(invocationExpr, taskVariableName, byRef: false);
        }

        /// <summary>
        /// Extracts the leading identifier of the receiver chain of <paramref name="memberAccessSyntax"/> by walking left
        /// through member accesses and invocations.
        /// </summary>
        /// <example>
        /// <c>task1.Result</c> yields <c>task1</c>; <c>task2.GetAwaiter().GetResult()</c> yields <c>task2</c>.
        /// </example>
        /// <returns>
        /// The identifier text, or <see langword="null"/> when the chain starts with anything else
        /// (for example <c>this</c>, an element access, or a parenthesized expression).
        /// </returns>
        private static string? GetTaskVariableName(MemberAccessExpressionSyntax memberAccessSyntax)
        {
            ExpressionSyntax? current = memberAccessSyntax.Expression;
            while (current is not null)
            {
                switch (current)
                {
                    case IdentifierNameSyntax identifierName:
                        return identifierName.Identifier.ValueText;
                    case MemberAccessExpressionSyntax memberAccess:
                        current = memberAccess.Expression;
                        break;
                    case InvocationExpressionSyntax invocation:
                        current = invocation.Expression;
                        break;
                    default:
                        return null;
                }
            }

            return null;
        }

        /// <summary>
        /// Determines whether any node inside <paramref name="enclosingBlock"/> that starts strictly between
        /// <paramref name="startExclusive"/> and <paramref name="endExclusive"/> may give <paramref name="variableName"/>
        /// a new value. Such a node is an assignment whose target is that bare identifier, or an invocation that
        /// receives it as a <c>ref</c>/<c>out</c> argument.
        /// </summary>
        /// <remarks>
        /// Assignments to members, elements or tuples (e.g. <c>this.x = ...</c>, <c>(a, b) = ...</c>) are not recognized
        /// as invalidating the variable.
        /// </remarks>
        private static bool IsVariableInvalidatedBetween(SyntaxNode enclosingBlock, string variableName, int startExclusive, int endExclusive)
        {
            foreach (SyntaxNode node in enclosingBlock.DescendantNodes())
            {
                if (node.SpanStart <= startExclusive || node.SpanStart >= endExclusive)
                {
                    continue;
                }

                // Has the task variable been assigned to a new task?
                // A pattern match (not a cast) is used because the left side is frequently not a simple identifier.
                if (node is AssignmentExpressionSyntax { Left: IdentifierNameSyntax assignedIdentifier }
                    && assignedIdentifier.Identifier.ValueText == variableName)
                {
                    return true;
                }

                // Has the task variable been passed by ref to a method?
                // If so, we must assume the worst case that the method has assigned it to a new task.
                if (node is InvocationExpressionSyntax invocation
                    && IsVariablePassedToInvocation(invocation, variableName, byRef: true))
                {
                    return true;
                }
            }

            return false;
        }

        /// <summary>
        /// Determines whether the task read by <paramref name="memberAccessSyntax"/> (e.g. <c>task1.Result</c>) is known to
        /// have completed. This is the case when it was passed to an awaited <c>Task.WhenAll</c> that precedes the access
        /// textually within the same innermost block, and it has not been invalidated since.
        /// </summary>
        /// <remarks>
        /// This is a textual-order heuristic. Only <c>Task.WhenAll</c> calls located inside the innermost block containing
        /// <paramref name="memberAccessSyntax"/> are considered, and control flow is not analyzed.
        /// </remarks>
        private static bool HasTaskCompleted(SyntaxNodeAnalysisContext context, MemberAccessExpressionSyntax memberAccessSyntax)
        {
            if (GetEnclosingBlock(memberAccessSyntax) is not SyntaxNode enclosingBlock)
            {
                return false;
            }

            // The variable name lets us later determine whether it has been used in a `Task.WhenAll` invocation.
            if (GetTaskVariableName(memberAccessSyntax) is not string taskVariableName)
            {
                return false;
            }

            // All `Task.WhenAll` invocations that precede the problematic member access and are within the enclosing block.
            // The query is lazy and enumerated exactly once below, so semantic lookups stop as soon as completion is proven.
            IEnumerable<InvocationExpressionSyntax> precedingWhenAllInvocations =
                from invoc in enclosingBlock.DescendantNodes().OfType<InvocationExpressionSyntax>()
                where memberAccessSyntax.SpanStart > invoc.Span.End &&
                      IsTaskCompletedWithWhenAll(context, invoc, taskVariableName)
                select invoc;

            // If any preceding `Task.WhenAll` invocation is not followed by an invalidation of the task variable before the
            // problematic member access, then we consider the task to be completed. Checking every candidate (rather than
            // only the first) matters because a later `Task.WhenAll` can re-establish completion after a reassignment.
            // Example:
            //   await Task.WhenAll(task1, task2, task3);
            //   task1 = Task.Run(...);    // Invalidates `task1`
            //   DoSomething(ref task2);   // Invalidates `task2`
            //   task1.Result;             // Warn
            //   task2.Result;             // Warn
            //   task3.Result;             // No warning, task3 has not been invalidated in between WhenAll and this problematic member access
            return precedingWhenAllInvocations.Any(whenAll =>
                !IsVariableInvalidatedBetween(enclosingBlock, taskVariableName, whenAll.Span.End, memberAccessSyntax.SpanStart));
        }
    }
}
