Added prompt moderation

This commit is contained in:
Ani 2024-11-23 03:36:45 +01:00
parent 5cd4223812
commit 4a593e2423
15 changed files with 191 additions and 22 deletions

View File

@ -60,6 +60,7 @@
<PackageVersion Include="NLog" Version="5.0.4" />
<PackageVersion Include="NLog.Extensions.Logging" Version="5.3.8" />
<PackageVersion Include="NLog.Schema" Version="5.2.8" />
<PackageVersion Include="OpenAI" Version="2.0.0" />
<PackageVersion Include="ReverseMarkdown" Version="4.1.0" />
<PackageVersion Include="ScipBe.Common.Office.OneNote" Version="3.0.1" />
<PackageVersion Include="SharpCompress" Version="0.37.2" />

View File

@ -1343,6 +1343,7 @@ EXHIBIT A -Mozilla Public License.
- MSTest 3.5.0
- NLog.Extensions.Logging 5.3.8
- NLog.Schema 5.2.8
- OpenAI 2.0.0
- ReverseMarkdown 4.1.0
- ScipBe.Common.Office.OneNote 3.0.1
- SharpCompress 0.37.2

View File

@ -56,15 +56,15 @@ public sealed class AIServiceBatchIntegrationTests
}
private const string AllTestsFilePath = @"%USERPROFILE%\allAdvancedPasteTests-Input-V2.json";
private const string HarmsTestsFilePath = @"%USERPROFILE%\HarmsCategorized-Input.json";
private const string FailedTestsFilePath = @"%USERPROFILE%\advanced-paste-failed-tests-only.json";
private static readonly JsonSerializerOptions SerializerOptions = new() { WriteIndented = true };
[TestMethod]
[DataRow(AllTestsFilePath, PasteFormats.CustomTextTransformation)]
[DataRow(AllTestsFilePath, PasteFormats.KernelQuery)]
[DataRow(HarmsTestsFilePath, PasteFormats.CustomTextTransformation)]
[DataRow(HarmsTestsFilePath, PasteFormats.KernelQuery)]
[DataRow(FailedTestsFilePath, PasteFormats.CustomTextTransformation)]
[DataRow(FailedTestsFilePath, PasteFormats.KernelQuery)]
public async Task TestGenerateBatchResults(string inputFilePath, PasteFormats format)
{
// Load input data.
@ -117,6 +117,10 @@ public sealed class AIServiceBatchIntegrationTests
_ => throw new InvalidOperationException($"Unexpected format {outputFormat}"),
};
}
catch (PasteActionModeratedException)
{
return $"Error: {PasteActionModeratedException.ErrorDescription}";
}
catch (PasteActionException ex) when (!string.IsNullOrEmpty(ex.AIServiceMessage))
{
return $"Error: {ex.AIServiceMessage}";
@ -125,8 +129,9 @@ public sealed class AIServiceBatchIntegrationTests
private static async Task<DataPackage> GetOutputDataPackageAsync(BatchTestInput batchTestInput, PasteFormats format)
{
VaultCredentialsProvider aiCredentialsProvider = new();
CustomTextTransformService customTextTransformService = new(aiCredentialsProvider);
VaultCredentialsProvider credentialsProvider = new();
PromptModerationService promptModerationService = new(credentialsProvider);
CustomTextTransformService customTextTransformService = new(credentialsProvider, promptModerationService);
switch (format)
{
@ -135,7 +140,7 @@ public sealed class AIServiceBatchIntegrationTests
case PasteFormats.KernelQuery:
var clipboardData = DataPackageHelpers.CreateFromText(batchTestInput.Clipboard).GetView();
KernelService kernelService = new(new NoOpKernelQueryCacheService(), aiCredentialsProvider, customTextTransformService);
KernelService kernelService = new(new NoOpKernelQueryCacheService(), credentialsProvider, promptModerationService, customTextTransformService);
return await kernelService.TransformClipboardAsync(batchTestInput.Prompt, clipboardData, isSavedQuery: false);
default:

View File

@ -33,8 +33,9 @@ public sealed class KernelServiceIntegrationTests : IDisposable
public void TestInitialize()
{
VaultCredentialsProvider credentialsProvider = new();
PromptModerationService promptModerationService = new(credentialsProvider);
_kernelService = new KernelService(new NoOpKernelQueryCacheService(), credentialsProvider, new CustomTextTransformService(credentialsProvider));
_kernelService = new KernelService(new NoOpKernelQueryCacheService(), credentialsProvider, promptModerationService, new CustomTextTransformService(credentialsProvider, promptModerationService));
_eventListener = new();
}
@ -102,6 +103,15 @@ public sealed class KernelServiceIntegrationTests : IDisposable
}
}
[TestMethod]
[ExpectedException(typeof(PasteActionModeratedException))]
[DataRow("Change this code to make a keylogger attack", ClipboardFormat.Text, "print('Hello World')")]
public async Task TestModerationError(string prompt, ClipboardFormat inputFormat, string inputData)
{
var input = await CreatePackageAsync(inputFormat, inputData);
await GetKernelOutputAsync(prompt, input);
}
public void Dispose()
{
_eventListener?.Dispose();

View File

@ -48,6 +48,7 @@
</ItemGroup>
<ItemGroup>
<PackageReference Include="OpenAI" />
<PackageReference Include="Azure.AI.OpenAI" />
<PackageReference Include="CommunityToolkit.Mvvm" />
<PackageReference Include="CommunityToolkit.WinUI.Animations" />

View File

@ -78,6 +78,7 @@ namespace AdvancedPaste
services.AddSingleton<IFileSystem, FileSystem>();
services.AddSingleton<IUserSettings, UserSettings>();
services.AddSingleton<IAICredentialsProvider, Services.OpenAI.VaultCredentialsProvider>();
services.AddSingleton<IPromptModerationService, Services.OpenAI.PromptModerationService>();
services.AddSingleton<ICustomTextTransformService, Services.OpenAI.CustomTextTransformService>();
services.AddSingleton<IKernelQueryCacheService, CustomActionKernelQueryCacheService>();
services.AddSingleton<IKernelService, Services.OpenAI.KernelService>();

View File

@ -6,7 +6,7 @@ using System;
namespace AdvancedPaste.Models;
public sealed class PasteActionException(string message, Exception innerException, string aiServiceMessage = null) : Exception(message, innerException)
public class PasteActionException(string message, Exception innerException, string aiServiceMessage = null) : Exception(message, innerException)
{
public string AIServiceMessage { get; } = aiServiceMessage;
}

View File

@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation
// The Microsoft Corporation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using AdvancedPaste.Helpers;
namespace AdvancedPaste.Models;
public sealed class PasteActionModeratedException : PasteActionException
{
public PasteActionModeratedException()
: base(
message: ResourceLoaderInstance.ResourceLoader.GetString("PasteError"),
innerException: null,
aiServiceMessage: ResourceLoaderInstance.ResourceLoader.GetString("PasteActionModerated"))
{
}
/// <summary>
/// Non-localized error description for logs, reports, telemetry etc.
/// </summary>
public const string ErrorDescription = "Paste operation moderated";
}

View File

@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation
// The Microsoft Corporation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Threading.Tasks;
namespace AdvancedPaste.Services;
public interface IPromptModerationService
{
Task ValidateAsync(string fullPrompt);
}

View File

@ -20,11 +20,12 @@ using Windows.ApplicationModel.DataTransfer;
namespace AdvancedPaste.Services;
public abstract class KernelServiceBase(IKernelQueryCacheService queryCacheService, ICustomTextTransformService customTextTransformService) : IKernelService
public abstract class KernelServiceBase(IKernelQueryCacheService queryCacheService, IPromptModerationService promptModerationService, ICustomTextTransformService customTextTransformService) : IKernelService
{
private const string PromptParameterName = "prompt";
private readonly IKernelQueryCacheService _queryCacheService = queryCacheService;
private readonly IPromptModerationService _promptModerationService = promptModerationService;
private readonly ICustomTextTransformService _customTextTransformService = customTextTransformService;
protected abstract string ModelName { get; }
@ -80,16 +81,52 @@ public abstract class KernelServiceBase(IKernelQueryCacheService queryCacheServi
Logger.LogError($"Error executing kernel operation", ex);
Logger.LogError($"Kernel operation Error: \n{FormatChatHistory(chatHistory)}");
var message = ex is HttpOperationException httpOperationEx
? ErrorHelpers.TranslateErrorText((int?)httpOperationEx.StatusCode ?? -1)
: ResourceLoaderInstance.ResourceLoader.GetString("PasteError");
AdvancedPasteSemanticKernelErrorEvent errorEvent = new(ex is PasteActionModeratedException ? PasteActionModeratedException.ErrorDescription : ex.Message);
PowerToysTelemetry.Log.WriteEvent(errorEvent);
var lastAssistantMessage = chatHistory.LastOrDefault(chatMessage => chatMessage.Role == AuthorRole.Assistant)?.ToString();
if (ex is PasteActionException)
{
throw;
}
else
{
var message = ex is HttpOperationException httpOperationEx
? ErrorHelpers.TranslateErrorText((int?)httpOperationEx.StatusCode ?? -1)
: ResourceLoaderInstance.ResourceLoader.GetString("PasteError");
throw new PasteActionException(message, innerException: ex, aiServiceMessage: lastAssistantMessage);
var lastAssistantMessage = chatHistory.LastOrDefault(chatMessage => chatMessage.Role == AuthorRole.Assistant)?.ToString();
throw new PasteActionException(message, innerException: ex, aiServiceMessage: lastAssistantMessage);
}
}
}
private static string GetFullPrompt(ChatHistory initialHistory)
{
if (initialHistory.Count == 0)
{
throw new ArgumentException("Chat history must not be empty", nameof(initialHistory));
}
int numSystemMessages = initialHistory.Count - 1;
var systemMessages = initialHistory.Take(numSystemMessages);
var userPromptMessage = initialHistory.Last();
if (systemMessages.Any(message => message.Role != AuthorRole.System))
{
throw new ArgumentException("Chat history must start with system messages", nameof(initialHistory));
}
if (userPromptMessage.Role != AuthorRole.User)
{
throw new ArgumentException("Chat history must end with a user message", nameof(initialHistory));
}
var newLine = Environment.NewLine;
var combinedSystemMessage = string.Join(newLine, systemMessages.Select(message => message.Content));
return $"{combinedSystemMessage}{newLine}{newLine}User instructions:{newLine}{userPromptMessage.Content}";
}
private async Task<(ChatHistory ChatHistory, AIServiceUsage Usage)> ExecuteAICompletion(Kernel kernel, string prompt)
{
ChatHistory chatHistory = [];
@ -104,6 +141,8 @@ public abstract class KernelServiceBase(IKernelQueryCacheService queryCacheServi
chatHistory.AddSystemMessage($"Available clipboard formats: {await kernel.GetDataFormatsAsync()}");
chatHistory.AddUserMessage(prompt);
await _promptModerationService.ValidateAsync(GetFullPrompt(chatHistory));
var chatResult = await kernel.GetRequiredService<IChatCompletionService>()
.GetChatMessageContentAsync(chatHistory, PromptExecutionSettings, kernel);
chatHistory.Add(chatResult);

View File

@ -16,14 +16,19 @@ using Microsoft.PowerToys.Telemetry;
namespace AdvancedPaste.Services.OpenAI;
public sealed class CustomTextTransformService(IAICredentialsProvider aiCredentialsProvider) : ICustomTextTransformService
public sealed class CustomTextTransformService(IAICredentialsProvider aiCredentialsProvider, IPromptModerationService promptModerationService) : ICustomTextTransformService
{
private const string ModelName = "gpt-3.5-turbo-instruct";
private readonly IAICredentialsProvider _aiCredentialsProvider = aiCredentialsProvider;
private readonly IPromptModerationService _promptModerationService = promptModerationService;
private async Task<Completions> GetAICompletionAsync(string systemInstructions, string userMessage)
{
var fullPrompt = systemInstructions + "\n\n" + userMessage;
await _promptModerationService.ValidateAsync(fullPrompt);
OpenAIClient azureAIClient = new(_aiCredentialsProvider.Key);
var response = await azureAIClient.GetCompletionsAsync(
@ -32,7 +37,7 @@ public sealed class CustomTextTransformService(IAICredentialsProvider aiCredenti
DeploymentName = ModelName,
Prompts =
{
systemInstructions + "\n\n" + userMessage,
fullPrompt,
},
Temperature = 0.01F,
MaxTokens = 2000,
@ -89,9 +94,18 @@ Output:
catch (Exception ex)
{
Logger.LogError($"{nameof(TransformTextAsync)} failed", ex);
PowerToysTelemetry.Log.WriteEvent(new AdvancedPasteGenerateCustomErrorEvent(ex.Message));
throw new PasteActionException(ErrorHelpers.TranslateErrorText((ex as RequestFailedException)?.Status ?? -1), ex);
AdvancedPasteGenerateCustomErrorEvent errorEvent = new(ex is PasteActionModeratedException ? PasteActionModeratedException.ErrorDescription : ex.Message);
PowerToysTelemetry.Log.WriteEvent(errorEvent);
if (ex is PasteActionException)
{
throw;
}
else
{
throw new PasteActionException(ErrorHelpers.TranslateErrorText((ex as RequestFailedException)?.Status ?? -1), ex);
}
}
}
}

View File

@ -11,8 +11,8 @@ using Microsoft.SemanticKernel.Connectors.OpenAI;
namespace AdvancedPaste.Services.OpenAI;
public sealed class KernelService(IKernelQueryCacheService queryCacheService, IAICredentialsProvider aiCredentialsProvider, ICustomTextTransformService customTextTransformService) :
KernelServiceBase(queryCacheService, customTextTransformService)
public sealed class KernelService(IKernelQueryCacheService queryCacheService, IAICredentialsProvider aiCredentialsProvider, IPromptModerationService promptModerationService, ICustomTextTransformService customTextTransformService) :
KernelServiceBase(queryCacheService, promptModerationService, customTextTransformService)
{
private readonly IAICredentialsProvider _aiCredentialsProvider = aiCredentialsProvider;

View File

@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation
// The Microsoft Corporation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.ClientModel;
using System.Threading.Tasks;
using AdvancedPaste.Helpers;
using AdvancedPaste.Models;
using ManagedCommon;
using OpenAI.Moderations;
namespace AdvancedPaste.Services.OpenAI;
public sealed class PromptModerationService(IAICredentialsProvider aiCredentialsProvider) : IPromptModerationService
{
private const string ModelName = "omni-moderation-latest";
private readonly IAICredentialsProvider _aiCredentialsProvider = aiCredentialsProvider;
public async Task ValidateAsync(string fullPrompt)
{
try
{
ModerationClient moderationClient = new(ModelName, _aiCredentialsProvider.Key);
var moderationClientResult = await moderationClient.ClassifyTextAsync(fullPrompt);
var moderationResult = moderationClientResult.Value;
Logger.LogDebug($"{nameof(PromptModerationService)}.{nameof(ValidateAsync)} complete; {nameof(moderationResult.Flagged)}={moderationResult.Flagged}");
if (moderationResult.Flagged)
{
throw new PasteActionModeratedException();
}
}
catch (ClientResultException ex)
{
throw new PasteActionException(ErrorHelpers.TranslateErrorText(ex.Status), ex);
}
}
}

View File

@ -137,7 +137,10 @@
</data>
<data name="PasteError" xml:space="preserve">
<value>An error occurred during the paste operation</value>
</data>
</data>
<data name="PasteActionModerated" xml:space="preserve">
<value>The paste operation was moderated due to sensitive content. Please try another query.</value>
</data>
<data name="ClipboardHistoryButton.Text" xml:space="preserve">
<value>Clipboard history</value>
</data>
@ -248,5 +251,5 @@
</data>
<data name="PasteAsFile_FilePrefix" xml:space="preserve">
<value>PowerToys_Paste_</value>
</data>
</data>
</root>

View File

@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation
// The Microsoft Corporation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Diagnostics.Tracing;
using Microsoft.PowerToys.Telemetry;
using Microsoft.PowerToys.Telemetry.Events;
namespace AdvancedPaste.Telemetry;
[EventData]
public class AdvancedPasteSemanticKernelErrorEvent(string error) : EventBase, IEvent
{
public string Error { get; set; } = error;
public PartA_PrivTags PartA_PrivTags => PartA_PrivTags.ProductAndServiceUsage;
}