我将使用Roslyn动态编译和执行代码,如下例所示。我想确保代码不会违反我的一些规则,比如:
- 不使用反射
- 不使用HttpClient或WebClient
- 不在System.IO命名空间中使用File或Directory类
- 不使用源生成器
- 不调用非托管代码
在下面的代码中,我应该把这些规则检查插入到哪里?又该如何实现这些检查?
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Emit;
using System.Reflection;
using System.Runtime.CompilerServices;
// Sample script to compile. It deliberately calls System.IO.File, which the rules are meant to reject.
string code = @"using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.IO;
namespace Customization
{
public class Script
{
public async Task<object?> RunAsync(object? data)
{
//The following should not be allowed
File.Delete(@""C:\Temp\log.txt"");
return await Task.FromResult(data);
}
}
}";

// Compile() is where the rule checks belong; Build() then emits the assembly and surfaces compile errors.
var compilation = Compile(code);
var bytes = Build(compilation);
Console.WriteLine("Done");
// Parses the submitted script and creates an (un-emitted) compilation for it against a small, fixed
// set of framework references. This is the place to enforce rules on the script's content: the syntax
// tree and the semantic model are both available here, before the compilation is returned.
// Throws InvalidOperationException if the framework directory cannot be determined.
CSharpCompilation Compile(string code)
{
    SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(code);

    // System.Runtime.dll is loaded from the directory of the assembly that defines System.Object.
    string? dotNetCoreDirectoryPath = Path.GetDirectoryName(typeof(object).Assembly.Location);
    if (string.IsNullOrWhiteSpace(dotNetCoreDirectoryPath))
    {
        // Not an argument error. Also, the single-string ArgumentNullException constructor would
        // have treated this text as a parameter name rather than as the message.
        throw new InvalidOperationException("Cannot determine path to current assembly.");
    }

    string assemblyName = Path.GetRandomFileName();

    // Metadata references the script is compiled against.
    List<MetadataReference> references = new()
    {
        MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Console).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Dictionary<,>).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Task).Assembly.Location),
        MetadataReference.CreateFromFile(Path.Combine(dotNetCoreDirectoryPath, "System.Runtime.dll")),
    };

    CSharpCompilation compilation = CSharpCompilation.Create(
        assemblyName,
        syntaxTrees: new[] { syntaxTree },
        references: references,
        options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));

    SemanticModel model = compilation.GetSemanticModel(syntaxTree);
    CompilationUnitSyntax root = (CompilationUnitSyntax)syntaxTree.GetRoot();

    // TODO: reject disallowed usages (e.g. System.IO.File) before returning. Each identifier can be
    // resolved with model.GetSymbolInfo / model.GetTypeInfo to find out what it actually refers to.
    IEnumerable<IdentifierNameSyntax> identifiers = root.DescendantNodes().OfType<IdentifierNameSyntax>();

    return compilation;
}
// Emits the compilation to an in-memory assembly image and returns its bytes.
// Emitting is what surfaces compile errors, so a failed emit is reported as an
// InvalidOperationException whose message lists every error diagnostic (id, location and text),
// one per line, rather than only the first one.
[MethodImpl(MethodImplOptions.NoInlining)]
byte[] Build(CSharpCompilation compilation)
{
    using MemoryStream ms = new();

    EmitResult emitResult = compilation.Emit(ms);
    if (!emitResult.Success)
    {
        // Diagnostic.ToString() includes the diagnostic id and source location along with the message.
        IEnumerable<string> errors = emitResult.Diagnostics
            .Where(diagnostic => diagnostic.IsWarningAsError || diagnostic.Severity == DiagnosticSeverity.Error)
            .Select(diagnostic => diagnostic.ToString());

        throw new InvalidOperationException(string.Join(Environment.NewLine, errors));
    }

    return ms.ToArray();
}
要检查某个特定类是否被使用,可以用 `OfType<>()` 方法查找 `IdentifierNameSyntax` 类型的节点,再按类名过滤结果:
// Candidate references: every simple identifier whose spelling matches the class name, ignoring case.
// Whether a candidate really is the banned type still has to be confirmed semantically.
var names =
    from identifier in root.DescendantNodes().OfType<IdentifierNameSyntax>()
    where string.Equals(identifier.Identifier.ValueText, className, StringComparison.OrdinalIgnoreCase)
    select identifier;
然后可以用 `SemanticModel` 检查该类所在的命名空间:
// Confirm each candidate through the semantic model. Only an identifier whose type lives in
// `containingNamespace` counts, so an unrelated type that merely shares the name is ignored.
foreach (var name in names)
{
    var typeInfo = model.GetTypeInfo(name);
    if (string.Equals(typeInfo.Type?.ContainingNamespace?.ToString(), containingNamespace, StringComparison.OrdinalIgnoreCase))
    {
        // A specific exception type instead of the base Exception, so callers can catch it deliberately.
        throw new InvalidOperationException($"Class {containingNamespace}.{className} is not allowed.");
    }
}
若要检查是否使用了反射或非托管代码,可以检查是否存在相关的 `using` 指令,即 `System.Reflection` 和 `System.Runtime.InteropServices`:
// Inspect every using directive, including ones nested inside a namespace declaration. Reject the
// namespace itself as well as its sub-namespaces (System.Reflection.Emit is still reflection), which
// also covers `using static` of a type inside it.
// The imported name is rebuilt from its tokens so whitespace or comments between the dots cannot
// disguise it, and a leading `global::` is ignored.
bool importsDisallowedNamespace = root.DescendantNodes()
    .OfType<UsingDirectiveSyntax>()
    // Name can be null in newer Roslyn versions (C# 12 alias-any-type, e.g. `using P = (int, int);`).
    .Where(u => u.Name is not null)
    .Select(u => string.Concat(u.Name!.DescendantTokens().Select(t => t.ValueText)))
    .Select(imported => imported.StartsWith("global::", StringComparison.Ordinal) ? imported.Substring("global::".Length) : imported)
    .Any(imported => string.Equals(imported, disallowedNamespace, StringComparison.OrdinalIgnoreCase) ||
                     imported.StartsWith(disallowedNamespace + ".", StringComparison.OrdinalIgnoreCase));
if (importsDisallowedNamespace)
{
    throw new InvalidOperationException($"Namespace {disallowedNamespace} is not allowed.");
}
这种做法也会拦截那些只写了 `using`、却并没有真正使用反射或非托管代码的情况,但这似乎是一种可以接受的折衷。反过来,它检测不到不写 `using`、直接使用完全限定名(例如 `System.Reflection.Assembly`)的代码。
我不确定该如何处理"不使用源生成器"这条规则:源生成器通常以项目引用的形式引入,所以我不清楚它们在动态编译的代码中会如何运行。(需要说明的是,像上面这样直接通过 `CSharpCompilation` 编译时,只有显式创建并运行 `GeneratorDriver` 才会执行源生成器。)
把这些检查放在同一个地方,并据此更新你的代码,得到:
using System.Reflection;
using System.Runtime.CompilerServices;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Emit;
// Sample script that deliberately breaks several rules: it uses System.IO.File and HttpClient and
// imports System.Reflection / System.Runtime.InteropServices, so Compile() is expected to throw.
// The script itself is syntactically valid, so a rule violation is what gets reported, not a parse error.
string code = @"using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.IO;
using System.Net.Http;
using System.Reflection;
using System.Runtime.InteropServices;
namespace Customization
{
public class Script
{
static readonly HttpClient client = new HttpClient();
public async Task<object?> RunAsync(object? data)
{
//The following should not be allowed
File.Delete(@""C:\Temp\log.txt"");
return await Task.FromResult(data);
}
}
}";
var compilation = Compile(code);
var bytes = Build(compilation);
Console.WriteLine("Done");
// Parses the submitted script, creates an (un-emitted) compilation for it, and enforces the sandbox
// rules on its source before handing the compilation back.
//
// Rules checked here (each check throws on the first violation it finds):
//   - System.IO.File and System.IO.Directory must not be used;
//   - System.Net.Http.HttpClient and System.Net.WebClient must not be used;
//   - System.Reflection and System.Runtime.InteropServices must not be imported with `using`;
//   - no `extern` declarations, so no [DllImport]-style native methods even when the attribute is
//     written fully qualified (this also rejects `extern alias`);
//   - no `unsafe` code (allowUnsafe: false, which is also the compiler default).
// Source generators only run when a GeneratorDriver is created and executed, which this code never
// does, so the script cannot bring one in.
//
// Throws InvalidOperationException if the framework directory cannot be determined.
CSharpCompilation Compile(string code)
{
    SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(code);

    // System.Runtime.dll is loaded from the directory of the assembly that defines System.Object.
    string? dotNetCoreDirectoryPath = Path.GetDirectoryName(typeof(object).GetTypeInfo().Assembly.Location);
    if (string.IsNullOrWhiteSpace(dotNetCoreDirectoryPath))
    {
        throw new InvalidOperationException("Cannot determine path to current assembly.");
    }

    string assemblyName = Path.GetRandomFileName();

    // Metadata references the script is compiled against.
    List<MetadataReference> references = new()
    {
        MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Console).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Dictionary<,>).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Task).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(HttpClient).Assembly.Location),
        MetadataReference.CreateFromFile(Path.Combine(dotNetCoreDirectoryPath, "System.Runtime.dll")),
    };

    CSharpCompilation compilation = CSharpCompilation.Create(
        assemblyName,
        syntaxTrees: new[] { syntaxTree },
        references: references,
        options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary, allowUnsafe: false));

    SemanticModel model = compilation.GetSemanticModel(syntaxTree);
    CompilationUnitSyntax root = (CompilationUnitSyntax)syntaxTree.GetRoot();

    ThrowOnDisallowedClass("File", "System.IO", root, model);
    ThrowOnDisallowedClass("Directory", "System.IO", root, model);
    ThrowOnDisallowedClass("HttpClient", "System.Net.Http", root, model);
    ThrowOnDisallowedClass("WebClient", "System.Net", root, model);
    ThrowOnDisallowedNamespace("System.Reflection", root);
    ThrowOnDisallowedNamespace("System.Runtime.InteropServices", root);

    // `extern` is how native entry points are declared ([DllImport] static extern ...). Rejecting the
    // keyword outright also covers a fully qualified attribute, which the using-based check cannot see.
    if (root.DescendantTokens().Any(token => token.IsKind(SyntaxKind.ExternKeyword)))
    {
        throw new InvalidOperationException("extern declarations are not allowed.");
    }

    return compilation;
}
// Emits the compilation to an in-memory assembly image and returns its bytes.
// Emitting is what surfaces compile errors, so a failed emit is reported as an
// InvalidOperationException whose message lists every error diagnostic (id, location and text),
// one per line, rather than only the first one.
[MethodImpl(MethodImplOptions.NoInlining)]
byte[] Build(CSharpCompilation compilation)
{
    using MemoryStream ms = new();

    EmitResult emitResult = compilation.Emit(ms);
    if (!emitResult.Success)
    {
        // Diagnostic.ToString() includes the diagnostic id and source location along with the message.
        IEnumerable<string> errors = emitResult.Diagnostics
            .Where(diagnostic => diagnostic.IsWarningAsError || diagnostic.Severity == DiagnosticSeverity.Error)
            .Select(diagnostic => diagnostic.ToString());

        throw new InvalidOperationException(string.Join(Environment.NewLine, errors));
    }

    return ms.ToArray();
}
// Throws InvalidOperationException if any identifier in `root` refers to the class
// `containingNamespace`.`className`; stops at the first match.
// Every identifier is resolved through `model`. A match is reported when the identifier's type lies in
// `containingNamespace` and either the identifier is spelled `className` (the original rule) or the
// type it resolves to is named `className`. The second condition also catches a type alias
// (`using F = System.IO.File;` followed by `F.Delete(...)`) and expressions whose type is the banned
// class, which a spelling-only match misses. Name and namespace comparisons ignore case.
void ThrowOnDisallowedClass(string className, string containingNamespace, CompilationUnitSyntax root, SemanticModel model)
{
    foreach (IdentifierNameSyntax name in root.DescendantNodes().OfType<IdentifierNameSyntax>())
    {
        ITypeSymbol? type = model.GetTypeInfo(name).Type;
        if (type is null)
        {
            continue;
        }

        bool refersToClassName =
            string.Equals(name.Identifier.ValueText, className, StringComparison.OrdinalIgnoreCase) ||
            string.Equals(type.Name, className, StringComparison.OrdinalIgnoreCase);

        if (refersToClassName &&
            string.Equals(type.ContainingNamespace?.ToString(), containingNamespace, StringComparison.OrdinalIgnoreCase))
        {
            throw new InvalidOperationException($"Class {containingNamespace}.{className} is not allowed.");
        }
    }
}
// Throws InvalidOperationException if any using directive in `root` imports `disallowedNamespace`
// or one of its sub-namespaces (System.Reflection.Emit counts as System.Reflection).
// Every directive is inspected: top-level ones, ones nested inside namespace declarations,
// `global using`, `using static` (a type inside the namespace) and aliases (whose target is checked).
// This is a purely syntactic check: code that names a type fully qualified, without any using
// directive (e.g. System.Reflection.Assembly.Load(...)), is NOT detected here.
void ThrowOnDisallowedNamespace(string disallowedNamespace, CompilationUnitSyntax root)
{
    string nestedPrefix = disallowedNamespace + ".";

    foreach (UsingDirectiveSyntax directive in root.DescendantNodes().OfType<UsingDirectiveSyntax>())
    {
        string? imported = ImportedName(directive);
        if (imported is null)
        {
            continue;
        }

        if (string.Equals(imported, disallowedNamespace, StringComparison.OrdinalIgnoreCase) ||
            imported.StartsWith(nestedPrefix, StringComparison.OrdinalIgnoreCase))
        {
            throw new InvalidOperationException($"Namespace {disallowedNamespace} is not allowed.");
        }
    }

    // The name a using directive imports, rebuilt from its tokens so whitespace, comments or @-escapes
    // inside it cannot disguise it, with any leading `global::` removed. Returns null when the
    // directive has no Name, which newer Roslyn versions report for C# 12 alias-any-type directives
    // such as `using P = (int, int);`.
    static string? ImportedName(UsingDirectiveSyntax directive)
    {
        if (directive.Name is null)
        {
            return null;
        }

        const string globalPrefix = "global::";
        string name = string.Concat(directive.Name.DescendantTokens().Select(token => token.ValueText));
        return name.StartsWith(globalPrefix, StringComparison.Ordinal) ? name.Substring(globalPrefix.Length) : name;
    }
}
这里用 `throw` 来处理违反规则的情况,这意味着无法一次报告多个违规;你可能需要调整这一点,让检查更高效。
`SymbolInfo` 类提供了编写规则、限制某些代码的使用时所需的部分元数据(metadata)。以下是我目前的思路,欢迎提出任何改进建议。
//Check for banned namespaces
// Each identifier is resolved through the semantic model, so the check applies to the symbol the name
// binds to rather than to its spelling: aliases and `var` resolve to their real targets.
string[] namespaceBlacklist = new string[] { "System.Net", "System.IO" };

// A namespace is banned if it is on the list or nested beneath a listed one
// (e.g. System.Net.Http under System.Net).
bool IsBannedNamespace(string ns) => namespaceBlacklist.Any(banned =>
    ns == banned || ns.StartsWith(banned + ".", StringComparison.Ordinal));

foreach (IdentifierNameSyntax identifier in identifiers)
{
    SymbolInfo symbolInfo = semanticModel.GetSymbolInfo(identifier);

    // When binding fails (e.g. ambiguous overloads) Symbol is null, but the candidate symbols
    // still show what the code was trying to reach.
    ISymbol? symbol = symbolInfo.Symbol ?? symbolInfo.CandidateSymbols.FirstOrDefault();
    if (symbol is null)
    {
        continue;
    }

    if (symbol is INamespaceSymbol namespaceSymbol)
    {
        string namespaceName = namespaceSymbol.ToDisplayString();
        if (IsBannedNamespace(namespaceName))
        {
            throw new InvalidOperationException($"Use of namespace '{namespaceName}' is not allowed.");
        }
    }
    else if (symbol.ContainingNamespace is { IsGlobalNamespace: false } containingNamespace
             && IsBannedNamespace(containingNamespace.ToDisplayString()))
    {
        // Types, methods, properties, fields, ... are judged by the namespace they are declared in, so
        // members reached through a variable (e.g. client.GetAsync) are caught as well as type names.
        throw new InvalidOperationException(
            $"Use of '{symbol.ToDisplayString()}' from namespace '{containingNamespace.ToDisplayString()}' is not allowed.");
    }
}