mirror of
https://github.com/microsoft/PowerToys.git
synced 2024-11-27 14:59:16 +08:00
Added prompt moderation
This commit is contained in:
parent
5cd4223812
commit
4a593e2423
@ -60,6 +60,7 @@
|
||||
<PackageVersion Include="NLog" Version="5.0.4" />
|
||||
<PackageVersion Include="NLog.Extensions.Logging" Version="5.3.8" />
|
||||
<PackageVersion Include="NLog.Schema" Version="5.2.8" />
|
||||
<PackageVersion Include="OpenAI" Version="2.0.0" />
|
||||
<PackageVersion Include="ReverseMarkdown" Version="4.1.0" />
|
||||
<PackageVersion Include="ScipBe.Common.Office.OneNote" Version="3.0.1" />
|
||||
<PackageVersion Include="SharpCompress" Version="0.37.2" />
|
||||
|
@ -1343,6 +1343,7 @@ EXHIBIT A -Mozilla Public License.
|
||||
- MSTest 3.5.0
|
||||
- NLog.Extensions.Logging 5.3.8
|
||||
- NLog.Schema 5.2.8
|
||||
- OpenAI 2.0.0
|
||||
- ReverseMarkdown 4.1.0
|
||||
- ScipBe.Common.Office.OneNote 3.0.1
|
||||
- SharpCompress 0.37.2
|
||||
|
@ -56,15 +56,15 @@ public sealed class AIServiceBatchIntegrationTests
|
||||
}
|
||||
|
||||
private const string AllTestsFilePath = @"%USERPROFILE%\allAdvancedPasteTests-Input-V2.json";
|
||||
private const string HarmsTestsFilePath = @"%USERPROFILE%\HarmsCategorized-Input.json";
|
||||
private const string FailedTestsFilePath = @"%USERPROFILE%\advanced-paste-failed-tests-only.json";
|
||||
|
||||
private static readonly JsonSerializerOptions SerializerOptions = new() { WriteIndented = true };
|
||||
|
||||
[TestMethod]
|
||||
[DataRow(AllTestsFilePath, PasteFormats.CustomTextTransformation)]
|
||||
[DataRow(AllTestsFilePath, PasteFormats.KernelQuery)]
|
||||
[DataRow(HarmsTestsFilePath, PasteFormats.CustomTextTransformation)]
|
||||
[DataRow(HarmsTestsFilePath, PasteFormats.KernelQuery)]
|
||||
[DataRow(FailedTestsFilePath, PasteFormats.CustomTextTransformation)]
|
||||
[DataRow(FailedTestsFilePath, PasteFormats.KernelQuery)]
|
||||
public async Task TestGenerateBatchResults(string inputFilePath, PasteFormats format)
|
||||
{
|
||||
// Load input data.
|
||||
@ -117,6 +117,10 @@ public sealed class AIServiceBatchIntegrationTests
|
||||
_ => throw new InvalidOperationException($"Unexpected format {outputFormat}"),
|
||||
};
|
||||
}
|
||||
catch (PasteActionModeratedException)
|
||||
{
|
||||
return $"Error: {PasteActionModeratedException.ErrorDescription}";
|
||||
}
|
||||
catch (PasteActionException ex) when (!string.IsNullOrEmpty(ex.AIServiceMessage))
|
||||
{
|
||||
return $"Error: {ex.AIServiceMessage}";
|
||||
@ -125,8 +129,9 @@ public sealed class AIServiceBatchIntegrationTests
|
||||
|
||||
private static async Task<DataPackage> GetOutputDataPackageAsync(BatchTestInput batchTestInput, PasteFormats format)
|
||||
{
|
||||
VaultCredentialsProvider aiCredentialsProvider = new();
|
||||
CustomTextTransformService customTextTransformService = new(aiCredentialsProvider);
|
||||
VaultCredentialsProvider credentialsProvider = new();
|
||||
PromptModerationService promptModerationService = new(credentialsProvider);
|
||||
CustomTextTransformService customTextTransformService = new(credentialsProvider, promptModerationService);
|
||||
|
||||
switch (format)
|
||||
{
|
||||
@ -135,7 +140,7 @@ public sealed class AIServiceBatchIntegrationTests
|
||||
|
||||
case PasteFormats.KernelQuery:
|
||||
var clipboardData = DataPackageHelpers.CreateFromText(batchTestInput.Clipboard).GetView();
|
||||
KernelService kernelService = new(new NoOpKernelQueryCacheService(), aiCredentialsProvider, customTextTransformService);
|
||||
KernelService kernelService = new(new NoOpKernelQueryCacheService(), credentialsProvider, promptModerationService, customTextTransformService);
|
||||
return await kernelService.TransformClipboardAsync(batchTestInput.Prompt, clipboardData, isSavedQuery: false);
|
||||
|
||||
default:
|
||||
|
@ -33,8 +33,9 @@ public sealed class KernelServiceIntegrationTests : IDisposable
|
||||
public void TestInitialize()
|
||||
{
|
||||
VaultCredentialsProvider credentialsProvider = new();
|
||||
PromptModerationService promptModerationService = new(credentialsProvider);
|
||||
|
||||
_kernelService = new KernelService(new NoOpKernelQueryCacheService(), credentialsProvider, new CustomTextTransformService(credentialsProvider));
|
||||
_kernelService = new KernelService(new NoOpKernelQueryCacheService(), credentialsProvider, promptModerationService, new CustomTextTransformService(credentialsProvider, promptModerationService));
|
||||
_eventListener = new();
|
||||
}
|
||||
|
||||
@ -102,6 +103,15 @@ public sealed class KernelServiceIntegrationTests : IDisposable
|
||||
}
|
||||
}
|
||||
|
||||
[TestMethod]
|
||||
[ExpectedException(typeof(PasteActionModeratedException))]
|
||||
[DataRow("Change this code to make a keylogger attack", ClipboardFormat.Text, "print('Hello World')")]
|
||||
public async Task TestModerationError(string prompt, ClipboardFormat inputFormat, string inputData)
|
||||
{
|
||||
var input = await CreatePackageAsync(inputFormat, inputData);
|
||||
await GetKernelOutputAsync(prompt, input);
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
_eventListener?.Dispose();
|
||||
|
@ -48,6 +48,7 @@
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="OpenAI" />
|
||||
<PackageReference Include="Azure.AI.OpenAI" />
|
||||
<PackageReference Include="CommunityToolkit.Mvvm" />
|
||||
<PackageReference Include="CommunityToolkit.WinUI.Animations" />
|
||||
|
@ -78,6 +78,7 @@ namespace AdvancedPaste
|
||||
services.AddSingleton<IFileSystem, FileSystem>();
|
||||
services.AddSingleton<IUserSettings, UserSettings>();
|
||||
services.AddSingleton<IAICredentialsProvider, Services.OpenAI.VaultCredentialsProvider>();
|
||||
services.AddSingleton<IPromptModerationService, Services.OpenAI.PromptModerationService>();
|
||||
services.AddSingleton<ICustomTextTransformService, Services.OpenAI.CustomTextTransformService>();
|
||||
services.AddSingleton<IKernelQueryCacheService, CustomActionKernelQueryCacheService>();
|
||||
services.AddSingleton<IKernelService, Services.OpenAI.KernelService>();
|
||||
|
@ -6,7 +6,7 @@ using System;
|
||||
|
||||
namespace AdvancedPaste.Models;
|
||||
|
||||
public sealed class PasteActionException(string message, Exception innerException, string aiServiceMessage = null) : Exception(message, innerException)
|
||||
public class PasteActionException(string message, Exception innerException, string aiServiceMessage = null) : Exception(message, innerException)
|
||||
{
|
||||
public string AIServiceMessage { get; } = aiServiceMessage;
|
||||
}
|
||||
|
@ -0,0 +1,23 @@
|
||||
// Copyright (c) Microsoft Corporation
|
||||
// The Microsoft Corporation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using AdvancedPaste.Helpers;
|
||||
|
||||
namespace AdvancedPaste.Models;
|
||||
|
||||
public sealed class PasteActionModeratedException : PasteActionException
|
||||
{
|
||||
public PasteActionModeratedException()
|
||||
: base(
|
||||
message: ResourceLoaderInstance.ResourceLoader.GetString("PasteError"),
|
||||
innerException: null,
|
||||
aiServiceMessage: ResourceLoaderInstance.ResourceLoader.GetString("PasteActionModerated"))
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Non-localized error description for logs, reports, telemetry etc.
|
||||
/// </summary>
|
||||
public const string ErrorDescription = "Paste operation moderated";
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
// Copyright (c) Microsoft Corporation
|
||||
// The Microsoft Corporation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace AdvancedPaste.Services;
|
||||
|
||||
public interface IPromptModerationService
|
||||
{
|
||||
Task ValidateAsync(string fullPrompt);
|
||||
}
|
@ -20,11 +20,12 @@ using Windows.ApplicationModel.DataTransfer;
|
||||
|
||||
namespace AdvancedPaste.Services;
|
||||
|
||||
public abstract class KernelServiceBase(IKernelQueryCacheService queryCacheService, ICustomTextTransformService customTextTransformService) : IKernelService
|
||||
public abstract class KernelServiceBase(IKernelQueryCacheService queryCacheService, IPromptModerationService promptModerationService, ICustomTextTransformService customTextTransformService) : IKernelService
|
||||
{
|
||||
private const string PromptParameterName = "prompt";
|
||||
|
||||
private readonly IKernelQueryCacheService _queryCacheService = queryCacheService;
|
||||
private readonly IPromptModerationService _promptModerationService = promptModerationService;
|
||||
private readonly ICustomTextTransformService _customTextTransformService = customTextTransformService;
|
||||
|
||||
protected abstract string ModelName { get; }
|
||||
@ -80,16 +81,52 @@ public abstract class KernelServiceBase(IKernelQueryCacheService queryCacheServi
|
||||
Logger.LogError($"Error executing kernel operation", ex);
|
||||
Logger.LogError($"Kernel operation Error: \n{FormatChatHistory(chatHistory)}");
|
||||
|
||||
var message = ex is HttpOperationException httpOperationEx
|
||||
? ErrorHelpers.TranslateErrorText((int?)httpOperationEx.StatusCode ?? -1)
|
||||
: ResourceLoaderInstance.ResourceLoader.GetString("PasteError");
|
||||
AdvancedPasteSemanticKernelErrorEvent errorEvent = new(ex is PasteActionModeratedException ? PasteActionModeratedException.ErrorDescription : ex.Message);
|
||||
PowerToysTelemetry.Log.WriteEvent(errorEvent);
|
||||
|
||||
var lastAssistantMessage = chatHistory.LastOrDefault(chatMessage => chatMessage.Role == AuthorRole.Assistant)?.ToString();
|
||||
if (ex is PasteActionException)
|
||||
{
|
||||
throw;
|
||||
}
|
||||
else
|
||||
{
|
||||
var message = ex is HttpOperationException httpOperationEx
|
||||
? ErrorHelpers.TranslateErrorText((int?)httpOperationEx.StatusCode ?? -1)
|
||||
: ResourceLoaderInstance.ResourceLoader.GetString("PasteError");
|
||||
|
||||
throw new PasteActionException(message, innerException: ex, aiServiceMessage: lastAssistantMessage);
|
||||
var lastAssistantMessage = chatHistory.LastOrDefault(chatMessage => chatMessage.Role == AuthorRole.Assistant)?.ToString();
|
||||
throw new PasteActionException(message, innerException: ex, aiServiceMessage: lastAssistantMessage);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static string GetFullPrompt(ChatHistory initialHistory)
|
||||
{
|
||||
if (initialHistory.Count == 0)
|
||||
{
|
||||
throw new ArgumentException("Chat history must not be empty", nameof(initialHistory));
|
||||
}
|
||||
|
||||
int numSystemMessages = initialHistory.Count - 1;
|
||||
var systemMessages = initialHistory.Take(numSystemMessages);
|
||||
var userPromptMessage = initialHistory.Last();
|
||||
|
||||
if (systemMessages.Any(message => message.Role != AuthorRole.System))
|
||||
{
|
||||
throw new ArgumentException("Chat history must start with system messages", nameof(initialHistory));
|
||||
}
|
||||
|
||||
if (userPromptMessage.Role != AuthorRole.User)
|
||||
{
|
||||
throw new ArgumentException("Chat history must end with a user message", nameof(initialHistory));
|
||||
}
|
||||
|
||||
var newLine = Environment.NewLine;
|
||||
|
||||
var combinedSystemMessage = string.Join(newLine, systemMessages.Select(message => message.Content));
|
||||
return $"{combinedSystemMessage}{newLine}{newLine}User instructions:{newLine}{userPromptMessage.Content}";
|
||||
}
|
||||
|
||||
private async Task<(ChatHistory ChatHistory, AIServiceUsage Usage)> ExecuteAICompletion(Kernel kernel, string prompt)
|
||||
{
|
||||
ChatHistory chatHistory = [];
|
||||
@ -104,6 +141,8 @@ public abstract class KernelServiceBase(IKernelQueryCacheService queryCacheServi
|
||||
chatHistory.AddSystemMessage($"Available clipboard formats: {await kernel.GetDataFormatsAsync()}");
|
||||
chatHistory.AddUserMessage(prompt);
|
||||
|
||||
await _promptModerationService.ValidateAsync(GetFullPrompt(chatHistory));
|
||||
|
||||
var chatResult = await kernel.GetRequiredService<IChatCompletionService>()
|
||||
.GetChatMessageContentAsync(chatHistory, PromptExecutionSettings, kernel);
|
||||
chatHistory.Add(chatResult);
|
||||
|
@ -16,14 +16,19 @@ using Microsoft.PowerToys.Telemetry;
|
||||
|
||||
namespace AdvancedPaste.Services.OpenAI;
|
||||
|
||||
public sealed class CustomTextTransformService(IAICredentialsProvider aiCredentialsProvider) : ICustomTextTransformService
|
||||
public sealed class CustomTextTransformService(IAICredentialsProvider aiCredentialsProvider, IPromptModerationService promptModerationService) : ICustomTextTransformService
|
||||
{
|
||||
private const string ModelName = "gpt-3.5-turbo-instruct";
|
||||
|
||||
private readonly IAICredentialsProvider _aiCredentialsProvider = aiCredentialsProvider;
|
||||
private readonly IPromptModerationService _promptModerationService = promptModerationService;
|
||||
|
||||
private async Task<Completions> GetAICompletionAsync(string systemInstructions, string userMessage)
|
||||
{
|
||||
var fullPrompt = systemInstructions + "\n\n" + userMessage;
|
||||
|
||||
await _promptModerationService.ValidateAsync(fullPrompt);
|
||||
|
||||
OpenAIClient azureAIClient = new(_aiCredentialsProvider.Key);
|
||||
|
||||
var response = await azureAIClient.GetCompletionsAsync(
|
||||
@ -32,7 +37,7 @@ public sealed class CustomTextTransformService(IAICredentialsProvider aiCredenti
|
||||
DeploymentName = ModelName,
|
||||
Prompts =
|
||||
{
|
||||
systemInstructions + "\n\n" + userMessage,
|
||||
fullPrompt,
|
||||
},
|
||||
Temperature = 0.01F,
|
||||
MaxTokens = 2000,
|
||||
@ -89,9 +94,18 @@ Output:
|
||||
catch (Exception ex)
|
||||
{
|
||||
Logger.LogError($"{nameof(TransformTextAsync)} failed", ex);
|
||||
PowerToysTelemetry.Log.WriteEvent(new AdvancedPasteGenerateCustomErrorEvent(ex.Message));
|
||||
|
||||
throw new PasteActionException(ErrorHelpers.TranslateErrorText((ex as RequestFailedException)?.Status ?? -1), ex);
|
||||
AdvancedPasteGenerateCustomErrorEvent errorEvent = new(ex is PasteActionModeratedException ? PasteActionModeratedException.ErrorDescription : ex.Message);
|
||||
PowerToysTelemetry.Log.WriteEvent(errorEvent);
|
||||
|
||||
if (ex is PasteActionException)
|
||||
{
|
||||
throw;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new PasteActionException(ErrorHelpers.TranslateErrorText((ex as RequestFailedException)?.Status ?? -1), ex);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -11,8 +11,8 @@ using Microsoft.SemanticKernel.Connectors.OpenAI;
|
||||
|
||||
namespace AdvancedPaste.Services.OpenAI;
|
||||
|
||||
public sealed class KernelService(IKernelQueryCacheService queryCacheService, IAICredentialsProvider aiCredentialsProvider, ICustomTextTransformService customTextTransformService) :
|
||||
KernelServiceBase(queryCacheService, customTextTransformService)
|
||||
public sealed class KernelService(IKernelQueryCacheService queryCacheService, IAICredentialsProvider aiCredentialsProvider, IPromptModerationService promptModerationService, ICustomTextTransformService customTextTransformService) :
|
||||
KernelServiceBase(queryCacheService, promptModerationService, customTextTransformService)
|
||||
{
|
||||
private readonly IAICredentialsProvider _aiCredentialsProvider = aiCredentialsProvider;
|
||||
|
||||
|
@ -0,0 +1,41 @@
|
||||
// Copyright (c) Microsoft Corporation
|
||||
// The Microsoft Corporation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System.ClientModel;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using AdvancedPaste.Helpers;
|
||||
using AdvancedPaste.Models;
|
||||
using ManagedCommon;
|
||||
using OpenAI.Moderations;
|
||||
|
||||
namespace AdvancedPaste.Services.OpenAI;
|
||||
|
||||
public sealed class PromptModerationService(IAICredentialsProvider aiCredentialsProvider) : IPromptModerationService
|
||||
{
|
||||
private const string ModelName = "omni-moderation-latest";
|
||||
|
||||
private readonly IAICredentialsProvider _aiCredentialsProvider = aiCredentialsProvider;
|
||||
|
||||
public async Task ValidateAsync(string fullPrompt)
|
||||
{
|
||||
try
|
||||
{
|
||||
ModerationClient moderationClient = new(ModelName, _aiCredentialsProvider.Key);
|
||||
var moderationClientResult = await moderationClient.ClassifyTextAsync(fullPrompt);
|
||||
var moderationResult = moderationClientResult.Value;
|
||||
|
||||
Logger.LogDebug($"{nameof(PromptModerationService)}.{nameof(ValidateAsync)} complete; {nameof(moderationResult.Flagged)}={moderationResult.Flagged}");
|
||||
|
||||
if (moderationResult.Flagged)
|
||||
{
|
||||
throw new PasteActionModeratedException();
|
||||
}
|
||||
}
|
||||
catch (ClientResultException ex)
|
||||
{
|
||||
throw new PasteActionException(ErrorHelpers.TranslateErrorText(ex.Status), ex);
|
||||
}
|
||||
}
|
||||
}
|
@ -137,7 +137,10 @@
|
||||
</data>
|
||||
<data name="PasteError" xml:space="preserve">
|
||||
<value>An error occurred during the paste operation</value>
|
||||
</data>
|
||||
</data>
|
||||
<data name="PasteActionModerated" xml:space="preserve">
|
||||
<value>The paste operation was moderated due to sensitive content. Please try another query.</value>
|
||||
</data>
|
||||
<data name="ClipboardHistoryButton.Text" xml:space="preserve">
|
||||
<value>Clipboard history</value>
|
||||
</data>
|
||||
@ -248,5 +251,5 @@
|
||||
</data>
|
||||
<data name="PasteAsFile_FilePrefix" xml:space="preserve">
|
||||
<value>PowerToys_Paste_</value>
|
||||
</data>
|
||||
</data>
|
||||
</root>
|
@ -0,0 +1,18 @@
|
||||
// Copyright (c) Microsoft Corporation
|
||||
// The Microsoft Corporation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System.Diagnostics.Tracing;
|
||||
|
||||
using Microsoft.PowerToys.Telemetry;
|
||||
using Microsoft.PowerToys.Telemetry.Events;
|
||||
|
||||
namespace AdvancedPaste.Telemetry;
|
||||
|
||||
[EventData]
|
||||
public class AdvancedPasteSemanticKernelErrorEvent(string error) : EventBase, IEvent
|
||||
{
|
||||
public string Error { get; set; } = error;
|
||||
|
||||
public PartA_PrivTags PartA_PrivTags => PartA_PrivTags.ProductAndServiceUsage;
|
||||
}
|
Loading…
Reference in New Issue
Block a user