From 016bda4010b217df43a321408dbd58e17d37a3ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Ku=CC=88hn?= Date: Thu, 15 Apr 2021 21:43:41 +0200 Subject: [PATCH] wip: fix extendNodeSchema and extendMarkSchema --- packages/core/src/Extension.ts | 3 +- packages/core/src/Mark.ts | 3 +- packages/core/src/helpers/getSchema.ts | 47 +++++++++++--------------- 3 files changed, 23 insertions(+), 30 deletions(-) diff --git a/packages/core/src/Extension.ts b/packages/core/src/Extension.ts index 1eef45a0a..73c56a6d1 100644 --- a/packages/core/src/Extension.ts +++ b/packages/core/src/Extension.ts @@ -3,6 +3,7 @@ import { Command as ProseMirrorCommand } from 'prosemirror-commands' import { InputRule } from 'prosemirror-inputrules' import { Editor } from './Editor' import { Node } from './Node' +import { Mark } from './Mark' import mergeDeep from './utilities/mergeDeep' import { GlobalAttributes, RawCommands, ParentConfig } from './types' import { ExtensionConfig } from '.' @@ -102,7 +103,7 @@ declare module '@tiptap/core' { options: Options, parent: ParentConfig>['extendMarkSchema'], }, - extension: Node, + extension: Mark, ) => { [key: string]: any, }) | null, diff --git a/packages/core/src/Mark.ts b/packages/core/src/Mark.ts index 767a70eb8..2c5fe04f0 100644 --- a/packages/core/src/Mark.ts +++ b/packages/core/src/Mark.ts @@ -14,6 +14,7 @@ import { GlobalAttributes, ParentConfig, } from './types' +import { Node } from './Node' import { MarkConfig } from '.' import { Editor } from './Editor' @@ -117,7 +118,7 @@ declare module '@tiptap/core' { options: Options, parent: ParentConfig>['extendMarkSchema'], }, - extension: Node, + extension: Mark, ) => { [key: string]: any, }) | null, diff --git a/packages/core/src/helpers/getSchema.ts b/packages/core/src/helpers/getSchema.ts index 05cc00749..92fabbf5f 100644 --- a/packages/core/src/helpers/getSchema.ts +++ b/packages/core/src/helpers/getSchema.ts @@ -1,5 +1,5 @@ import { NodeSpec, MarkSpec, Schema } from 'prosemirror-model' -import { Extensions } from '../types' +import { AnyConfig, Extensions } from '../types' import { ExtensionConfig, NodeConfig, MarkConfig } from '..' import splitExtensions from './splitExtensions' import getAttributesFromExtensions from './getAttributesFromExtensions' @@ -22,27 +22,10 @@ function cleanUpSchemaItem(data: T) { export default function getSchema(extensions: Extensions): Schema { const allAttributes = getAttributesFromExtensions(extensions) const { nodeExtensions, markExtensions } = splitExtensions(extensions) - const topNode = nodeExtensions.find(extension => extension.config.topNode)?.config.name - const nodeSchemaExtenders: ( - | ExtensionConfig['extendNodeSchema'] - | NodeConfig['extendNodeSchema'] - | MarkConfig['extendNodeSchema'] - )[] = [] - const markSchemaExtenders: ( - | ExtensionConfig['extendNodeSchema'] - | NodeConfig['extendNodeSchema'] - | MarkConfig['extendNodeSchema'] - )[] = [] - - extensions.forEach(extension => { - if (typeof extension.config.extendNodeSchema === 'function') { - nodeSchemaExtenders.push(extension.config.extendNodeSchema) - } - - if (typeof extension.config.extendMarkSchema === 'function') { - markSchemaExtenders.push(extension.config.extendMarkSchema) - } - }) + const topNodeExtension = nodeExtensions.find(extension => getExtensionField(extension, 'topNode')) + const topNode = topNodeExtension + ? getExtensionField(topNodeExtension, 'name') + : null const nodes = Object.fromEntries(nodeExtensions.map(extension => { const extensionAttributes = allAttributes.filter(attribute => attribute.type === extension.config.name) @@ -50,12 +33,16 @@ export default function getSchema(extensions: Extensions): Schema { options: extension.options, } - const extraNodeFields = nodeSchemaExtenders.reduce((fields, nodeSchemaExtender) => { - const extraFields = callOrReturn(nodeSchemaExtender, context, extension) + const extraNodeFields = extensions.reduce((fields, e) => { + const extendNodeSchema = getExtensionField( + e, + 'extendNodeSchema', + context, + ) return { ...fields, - ...extraFields, + ...(extendNodeSchema ? extendNodeSchema(extension) : {}), } }, {}) @@ -101,12 +88,16 @@ export default function getSchema(extensions: Extensions): Schema { options: extension.options, } - const extraMarkFields = markSchemaExtenders.reduce((fields, markSchemaExtender) => { - const extraFields = callOrReturn(markSchemaExtender, context, extension) + const extraMarkFields = extensions.reduce((fields, e) => { + const extendMarkSchema = getExtensionField( + e, + 'extendMarkSchema', + context, + ) return { ...fields, - ...extraFields, + ...(extendMarkSchema ? extendMarkSchema(extension) : {}), } }, {})