需要检查动态编译的代码是否使用了某些被禁止的标识符



我将使用Roslyn动态编译和执行代码,如下例所示。我想确保代码不会违反我的一些规则,比如:

  • 不使用反射
  • 不使用HttpClient或WebClient
  • 不在System.IO命名空间中使用File或Directory类
  • 不使用源生成器
  • 不调用非托管代码

在下面的代码中,应该把这些规则检查插入到哪里?又该如何实现它们?

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Emit;
using System.Reflection;
using System.Runtime.CompilerServices;
// Sample script to compile. It deliberately calls System.IO.File, which should not be allowed.
string code = @"using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.IO;
namespace Customization
{
public class Script
{
public async Task<object?> RunAsync(object? data)
{
//The following should not be allowed
File.Delete(@""C:Templog.txt"");
return await Task.FromResult(data);
}
}
}";

// Compile() is where the rule checks belong; Build() then emits the assembly and surfaces any build errors.
byte[] assemblyBytes = Build(Compile(code));
Console.WriteLine("Done");
// Parses the script and creates an in-memory class-library compilation that references the
// core runtime assemblies. The semantic model, root and identifier list set up at the end are
// the inputs for the rule checks described in the TODO; no rule is enforced here yet.
CSharpCompilation Compile(string code)
{
    SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(code);
    string? dotNetCoreDirectoryPath = Path.GetDirectoryName(typeof(object).GetTypeInfo().Assembly.Location);
    if (String.IsNullOrWhiteSpace(dotNetCoreDirectoryPath))
    {
        // An environment problem, not a bad argument. ArgumentNullException(string) treats its
        // only argument as a parameter *name*, so it would misreport this sentence.
        // Assembly.Location can be empty, e.g. for assemblies bundled into a single-file app.
        throw new InvalidOperationException("Cannot determine path to current assembly.");
    }
    string assemblyName = Path.GetRandomFileName();
    List<MetadataReference> references = new();
    references.Add(MetadataReference.CreateFromFile(typeof(object).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Console).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Dictionary<,>).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(typeof(Task).Assembly.Location));
    references.Add(MetadataReference.CreateFromFile(Path.Combine(dotNetCoreDirectoryPath, "System.Runtime.dll")));
    CSharpCompilation compilation = CSharpCompilation.Create(
        assemblyName,
        syntaxTrees: new[] { syntaxTree },
        references: references,
        options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));

    SemanticModel model = compilation.GetSemanticModel(syntaxTree);
    CompilationUnitSyntax root = (CompilationUnitSyntax)syntaxTree.GetRoot();
    //TODO: Check the code for use classes that are not allowed such as File in the System.IO namespace.
    // Each identifier can be resolved with model.GetSymbolInfo(...) / model.GetTypeInfo(...), and the
    // resolved symbol's type and namespace compared against the rules.
    IEnumerable<IdentifierNameSyntax> identifiers = root.DescendantNodes().OfType<IdentifierNameSyntax>();

    return compilation;
}
// Emits the compilation into an in-memory PE image and returns its bytes. Emitting is also how
// build errors surface: on failure, the message of the first error diagnostic (including a
// warning promoted to an error) is thrown. NoInlining is retained; its purpose isn't visible here.
[MethodImpl(MethodImplOptions.NoInlining)]
byte[] Build(CSharpCompilation compilation)
{
    using MemoryStream peImage = new();

    EmitResult result = compilation.Emit(peImage);
    if (result.Success)
    {
        return peImage.ToArray();
    }

    Diagnostic? firstFailure = result.Diagnostics.FirstOrDefault(
        d => d.IsWarningAsError || d.Severity == DiagnosticSeverity.Error);
    throw new Exception(firstFailure?.GetMessage());
}

检查特定类的使用情况时,可以使用OfType<>()方法查找IdentifierNameSyntax类型的节点,并按类名过滤结果:

// Collect every identifier spelled like the class name (case-insensitively); the semantic
// model is then used to decide which of them really refer to the class.
var names =
    from identifier in root.DescendantNodes().OfType<IdentifierNameSyntax>()
    where string.Equals(identifier.Identifier.ValueText, className, StringComparison.OrdinalIgnoreCase)
    select identifier;

然后可以使用SemanticModel检查类的名称空间:

foreach (var name in names)
{
    // Resolve what the name actually refers to, and compare that type's own name as well as its
    // namespace. Checking the namespace alone would also reject an identifier that merely shares
    // the class's spelling, e.g. a local `file` whose type is System.IO.FileInfo.
    ITypeSymbol? type = model.GetSymbolInfo(name).Symbol as ITypeSymbol ?? model.GetTypeInfo(name).Type;
    if (type is not null
        && string.Equals(type.Name, className, StringComparison.OrdinalIgnoreCase)
        && string.Equals(type.ContainingNamespace?.ToString(), containingNamespace, StringComparison.OrdinalIgnoreCase))
    {
        throw new Exception($"Class {containingNamespace}.{className} is not allowed.");
    }
}

若要检查反射或非托管代码的使用情况,可以检查是否存在相关的 using 指令(`System.Reflection` 和 `System.Runtime.InteropServices`):

// Scan every using directive in the file, not just the file-level ones in root.Usings. This also
// covers directives nested inside a namespace declaration, `using static` and alias forms, and
// it treats sub-namespaces (e.g. System.Reflection.Emit) as disallowed too. The name is rebuilt
// from the identifier tokens' ValueText, so whitespace, comments, `global::` or unicode escapes
// inside the directive cannot disguise it.
if (root.DescendantNodes()
        .OfType<UsingDirectiveSyntax>()
        .Select(directive => directive.Name)
        .OfType<NameSyntax>()
        .Select(name => string.Join(".", name.DescendantTokens()
            .Where(token => token.IsKind(SyntaxKind.IdentifierToken))
            .Select(token => token.ValueText)))
        .Any(imported =>
            string.Equals(imported, disallowedNamespace, StringComparison.OrdinalIgnoreCase)
            || imported.StartsWith(disallowedNamespace + ".", StringComparison.OrdinalIgnoreCase)))
{
    throw new Exception($"Namespace {disallowedNamespace} is not allowed.");
}

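这种做法也会拦截那些只写了 using、却并未真正使用反射或非托管代码的脚本(即误报),不过这似乎是可以接受的折衷。反过来,它也发现不了不经过 using 指令的用法,例如直接写全限定名 `System.Reflection.Assembly.Load(...)`,或者通过 `typeof(T).GetMethod(...)` 间接使用反射。因此,它不能单独作为安全边界。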

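我不确定该如何检查源代码生成器:它们通常以项目引用(分析器)的形式加入项目,所以我不清楚它们在动态编译的代码中如何运行。需要说明的是,像上面这样直接调用 `CSharpCompilation.Emit` 时,Roslyn 不会自动运行任何源代码生成器;只有宿主程序显式通过 `CSharpGeneratorDriver` 运行的生成器才会生效。因此在这种用法下,脚本本身无法引入源代码生成器。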

把这些检查集中放在同一处并更新你的代码后,得到:

using System.Reflection;
using System.Runtime.CompilerServices;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Emit;
// Sample script that violates several rules: it uses System.IO.File and HttpClient and imports
// System.Reflection and System.Runtime.InteropServices, so Compile() rejects it. The script is
// kept syntactically valid (no duplicate usings, every using terminated with ';') so that any
// failure comes from the rule checks rather than from a parse error.
string code = @"using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.IO;
using System.Net.Http;
using System.Reflection;
using System.Runtime.InteropServices;
namespace Customization
{
public class Script
{
static readonly HttpClient client = new HttpClient();
public async Task<object?> RunAsync(object? data)
{
//The following should not be allowed
File.Delete(@""C:Templog.txt"");
return await Task.FromResult(data);
}
}
}";
var compilation = Compile(code);
var bytes = Build(compilation);
Console.WriteLine("Done");

// Parses the script and builds an in-memory class-library compilation against the core runtime
// assemblies. HttpClient's assembly is referenced too, so the script's HttpClient usage binds.
// The rule checks then run against the syntax tree and semantic model; each throws on the first
// violation it finds, before the compilation is returned.
CSharpCompilation Compile(string code)
{
    var tree = CSharpSyntaxTree.ParseText(code);

    var runtimeDirectory = Path.GetDirectoryName(typeof(object).GetTypeInfo().Assembly.Location);
    if (string.IsNullOrWhiteSpace(runtimeDirectory))
    {
        throw new InvalidOperationException("Cannot determine path to current assembly.");
    }

    var assemblyName = Path.GetRandomFileName();
    var references = new List<MetadataReference>
    {
        MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Console).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Dictionary<,>).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Task).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(HttpClient).Assembly.Location),
        MetadataReference.CreateFromFile(Path.Combine(runtimeDirectory, "System.Runtime.dll")),
    };

    var compilation = CSharpCompilation.Create(
        assemblyName,
        syntaxTrees: new[] { tree },
        references: references,
        options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));

    var model = compilation.GetSemanticModel(tree);
    var root = (CompilationUnitSyntax)tree.GetRoot();

    // Rule checks.
    ThrowOnDisallowedClass("File", "System.IO", root, model);
    ThrowOnDisallowedClass("HttpClient", "System.Net.Http", root, model);
    ThrowOnDisallowedNamespace("System.Reflection", root);
    ThrowOnDisallowedNamespace("System.Runtime.InteropServices", root);

    return compilation;
}
// Emits the compilation into an in-memory PE image and returns its bytes. Emitting is also how
// build errors surface: on failure, the message of the first error diagnostic (including a
// warning promoted to an error) is thrown. NoInlining is retained; its purpose isn't visible here.
[MethodImpl(MethodImplOptions.NoInlining)]
byte[] Build(CSharpCompilation compilation)
{
    using MemoryStream peImage = new();

    EmitResult result = compilation.Emit(peImage);
    if (result.Success)
    {
        return peImage.ToArray();
    }

    Diagnostic? firstFailure = result.Diagnostics.FirstOrDefault(
        d => d.IsWarningAsError || d.Severity == DiagnosticSeverity.Error);
    throw new Exception(firstFailure?.GetMessage());
}
// Throws if the script references the type <containingNamespace>.<className> in any way the
// semantic model can see. Type and namespace names are compared case-insensitively.
//
// Every simple name in the file is inspected, not only names spelled like the class, because
// the type can be reached without its name at the use site: through an alias
// (`using F = System.IO.File; F.Delete(...)`), a `using static` import (`Delete(...)`), or a
// `var` local whose inferred type is the class. A name matches when it binds to the type itself,
// when it binds to a member declared by the type, or when it is an expression whose type is the
// class. Comparing the resolved type's own name (not just its namespace) avoids flagging an
// unrelated identifier that merely shares the class's spelling, e.g. a local `file` of type
// System.IO.FileInfo.
void ThrowOnDisallowedClass(string className, string containingNamespace, CompilationUnitSyntax root, SemanticModel model)
{
    foreach (SimpleNameSyntax name in root.DescendantNodes().OfType<SimpleNameSyntax>())
    {
        SymbolInfo symbolInfo = model.GetSymbolInfo(name);
        bool referencesClass =
            IsDisallowedSymbol(symbolInfo.Symbol)
            // If overload resolution fails, only candidates remain; check them as well rather
            // than letting the name through.
            || symbolInfo.CandidateSymbols.Any(IsDisallowedSymbol)
            || IsDisallowedType(model.GetTypeInfo(name).Type);

        if (referencesClass)
        {
            throw new Exception($"Class {containingNamespace}.{className} is not allowed.");
        }
    }

    // A symbol counts if it is the class itself or a member declared by it.
    bool IsDisallowedSymbol(ISymbol? symbol) =>
        symbol is not null && IsDisallowedType(symbol as ITypeSymbol ?? symbol.ContainingType);

    bool IsDisallowedType(ITypeSymbol? type) =>
        type is not null
        && string.Equals(type.Name, className, StringComparison.OrdinalIgnoreCase)
        && string.Equals(type.ContainingNamespace?.ToString(), containingNamespace, StringComparison.OrdinalIgnoreCase);
}
// Throws if any using directive imports disallowedNamespace or one of its sub-namespaces.
// The comparison is case-insensitive.
//
// Every using directive in the file is scanned, not just the file-level ones in root.Usings. This
// covers directives nested inside a namespace declaration as well as `using static`, alias and
// `global using` forms.
//
// NOTE: this is a purely syntactic check. Code that reaches the namespace without a using
// directive (a fully-qualified name, or `var m = typeof(T).GetMethod(...)`) is not detected;
// combine it with a semantic-model check.
void ThrowOnDisallowedNamespace(string disallowedNamespace, CompilationUnitSyntax root)
{
    bool importsDisallowedNamespace = root.DescendantNodes()
        .OfType<UsingDirectiveSyntax>()
        .Select(directive => directive.Name)
        .OfType<NameSyntax>()
        .Select(GetNamespaceText)
        .Any(imported =>
            string.Equals(imported, disallowedNamespace, StringComparison.OrdinalIgnoreCase)
            || imported.StartsWith(disallowedNamespace + ".", StringComparison.OrdinalIgnoreCase));

    if (importsDisallowedNamespace)
    {
        throw new Exception($"Namespace {disallowedNamespace} is not allowed.");
    }

    // Rebuilds "A.B.C" from the identifier tokens' ValueText. The raw text of the directive can
    // differ from the real name through `global::`, whitespace, comments or unicode escapes, and
    // would otherwise slip past the comparison.
    static string GetNamespaceText(NameSyntax name) =>
        string.Join(".", name.DescendantTokens()
            .Where(token => token.IsKind(SyntaxKind.IdentifierToken))
            .Select(token => token.ValueText));
}

这里我用 throw 来处理规则违规,这意味着一次只会报告第一个违规,而不会同时报告多个。你可能需要调整一下,比如先收集所有违规再统一报告,这样会更实用。

SymbolInfo 类提供了创建规则、限制某些代码用法所需的一些元数据。以下是我目前的思路,如有任何改进建议,不胜感激。

//Check for banned namespaces.
// Each identifier is resolved with the semantic model and rejected when it refers to a
// blacklisted namespace (or a sub-namespace of one), to a type declared in one, or to a member of
// such a type. The last case also covers methods brought into scope by `using static`.
string[] namespaceBlacklist = new string[] { "System.Net", "System.IO" };
foreach (IdentifierNameSyntax identifier in identifiers)
{
    SymbolInfo symbolInfo = semanticModel.GetSymbolInfo(identifier);

    // When overload resolution fails, Symbol is null but CandidateSymbols still lists what the
    // name could refer to; check those too instead of silently letting the name through.
    List<ISymbol> symbols = new(symbolInfo.CandidateSymbols);
    if (symbolInfo.Symbol is { } boundSymbol)
    {
        symbols.Add(boundSymbol);
    }

    foreach (ISymbol symbol in symbols)
    {
        // The namespace this use belongs to is one of:
        //   - the namespace itself;
        //   - the namespace a type is declared in;
        //   - for any other symbol, the namespace of the type that declares it.
        INamespaceSymbol? containingNamespace = symbol switch
        {
            INamespaceSymbol namespaceSymbol => namespaceSymbol,
            ITypeSymbol typeSymbol => typeSymbol.ContainingNamespace,
            _ => symbol.ContainingType?.ContainingNamespace,
        };
        if (containingNamespace is null)
        {
            continue;
        }

        string namespaceName = containingNamespace.ToDisplayString();
        // Ordinal comparison: namespace names are identifiers, not culture-sensitive text.
        if (namespaceBlacklist.Any(banned =>
                namespaceName == banned
                || namespaceName.StartsWith(banned + ".", StringComparison.Ordinal)))
        {
            throw new Exception($"Use of namespace '{namespaceName}' is not allowed.");
        }
    }
}

相关内容

  • 没有找到相关文章

最新更新