diff --git a/packages/core/src/commands/toggleList.ts b/packages/core/src/commands/toggleList.ts index a912f7277..50c29ded1 100644 --- a/packages/core/src/commands/toggleList.ts +++ b/packages/core/src/commands/toggleList.ts @@ -1,9 +1,53 @@ import { NodeType } from 'prosemirror-model' +import { Transaction } from 'prosemirror-state' +import { canJoin } from 'prosemirror-transform' import { RawCommands } from '../types' import { getNodeType } from '../helpers/getNodeType' import { findParentNode } from '../helpers/findParentNode' import { isList } from '../helpers/isList' +const joinListBackwards = (tr: Transaction, listType: NodeType): boolean => { + const list = findParentNode(node => node.type === listType)(tr.selection) + + if (!list) { + return true + } + + const before = tr.doc.resolve(Math.max(0, list.pos - 1)).before(list.depth) + const nodeBefore = tr.doc.nodeAt(before) + const canJoinBackwards = list.node.type === nodeBefore?.type + && canJoin(tr.doc, list.pos) + + if (!canJoinBackwards) { + return true + } + + tr.join(list.pos) + + return true +} + +const joinListForwards = (tr: Transaction, listType: NodeType): boolean => { + const list = findParentNode(node => node.type === listType)(tr.selection) + + if (!list) { + return true + } + + const after = tr.doc.resolve(list.start).after(list.depth) + const nodeAfter = tr.doc.nodeAt(after) + const canJoinForwards = list.node.type === nodeAfter?.type + && canJoin(tr.doc, after) + + if (!canJoinForwards) { + return true + } + + tr.join(after) + + return true +} + declare module '@tiptap/core' { interface Commands { toggleList: { @@ -43,21 +87,31 @@ export const toggleList: RawCommands['toggleList'] = (listTypeOrName, itemTypeOr && listType.validContent(parentList.node.content) && dispatch ) { - tr.setNodeMarkup(parentList.pos, listType) + return chain() + .command(() => { + tr.setNodeMarkup(parentList.pos, listType) - return true + return true + }) + .command(() => joinListBackwards(tr, listType)) + .command(() => joinListForwards(tr, listType)) + .run() } } - const canWrapInList = can().wrapInList(listType) + return chain() + // try to convert node to default node if needed + .command(() => { + const canWrapInList = can().wrapInList(listType) - // try to convert node to paragraph if needed - if (!canWrapInList) { - return chain() - .clearNodes() - .wrapInList(listType) - .run() - } + if (canWrapInList) { + return true + } - return commands.wrapInList(listType) + return commands.clearNodes() + }) + .wrapInList(listType) + .command(() => joinListBackwards(tr, listType)) + .command(() => joinListForwards(tr, listType)) + .run() }