import {
  extension,
  getMarkRanges,
  getTextSelection,
  Helper,
  helper,
  PrimitiveSelection,
} from '@remirror/core'
import { FontFamilyExtension } from '@remirror/extension-font-family'
import {
  ApplySchemaAttributes,
  ExtensionPriority,
  isElementDomNode,
  isString,
  joinStyles,
  MarkExtensionSpec,
  MarkSpecOverride,
  omitExtraAttributes,
  ProsemirrorAttributes,
} from 'remirror'
import { fontFamilyAvailable } from './FontFamilyCustom.constants'

/**
 * Fallback font: used when a pasted `font-family` lists no supported font,
 * and reported when the selection carries no font-family mark.
 */
const DEFAULT_FONT_FAMILY = 'Lato'

/** Data attribute written on the rendered `<span>` and matched when parsing it back. */
const FONT_FAMILY_ATTRIBUTE = 'data-font-family'

/** The mark's own attributes, after the schema's extra attributes are omitted. */
type FontFamilyAttributes = ProsemirrorAttributes<{ fontFamily?: string }>

/**
 * Reduce a CSS `font-family` value to the fonts this editor supports.
 *
 * Entries are split on commas (with or without surrounding whitespace, so both
 * `Arial, Lato` and `Arial,Lato` work), stripped of quotes and whitespace, and
 * kept only if listed in `fontFamilyAvailable`, preserving their order.
 *
 * @param cssValue - raw `font-family` value, e.g. `"Open Sans", Arial, sans-serif`.
 * @returns the supported fonts joined with `', '`, or `DEFAULT_FONT_FAMILY`
 *   when none of the listed fonts is supported.
 */
function toSupportedFontFamily(cssValue: string): string {
  const supportedFonts = cssValue
    .split(',')
    .map((font) => font.replace(/['"]+/g, '').trim())
    .filter((font) => fontFamilyAvailable.includes(font))

  return supportedFonts.length
    ? supportedFonts.join(', ')
    : DEFAULT_FONT_FAMILY
}

/**
 * Add a font family to the selected text (or text within a specified range).
 *
 * Extends the stock `FontFamilyExtension` with a custom mark spec. Fonts read
 * from inline styles (e.g. pasted content) are restricted to
 * `fontFamilyAvailable`, and the result falls back to `DEFAULT_FONT_FAMILY`.
 */
// NOTE(review): suppresses a type error on the decorator below whose cause is
// not visible here; consider `@ts-expect-error` with an explanation instead.
// @ts-ignore
@extension({})
export class FontFamilyCustomExtension extends FontFamilyExtension {
  /**
   * Build the mark spec: a `<span>` carrying the font both as a data attribute
   * and as an inline `font-family` style.
   *
   * @param extra - helpers for the extra attributes configured on the schema.
   * @param override - user-supplied spec overrides; its `parseDOM` rules are
   *   appended after the rules defined here.
   */
  createMarkSpec(
    extra: ApplySchemaAttributes,
    override: MarkSpecOverride,
  ): MarkExtensionSpec {
    return {
      ...override,
      attrs: { ...extra.defaults(), fontFamily: { default: null } },
      parseDOM: [
        {
          // Spans rendered by `toDOM` below carry the font in a data attribute.
          tag: `span[${FONT_FAMILY_ATTRIBUTE}]`,
          getAttrs: (dom: string | HTMLElement) => {
            if (!isElementDomNode(dom)) {
              return false
            }

            const fontFamily = dom.getAttribute(FONT_FAMILY_ATTRIBUTE)

            if (!fontFamily) {
              return false
            }

            return { ...extra.parse(dom), fontFamily }
          },
        },
        {
          // Read the font family from the inline `font-family` style. This is
          // useful for pasted content; unsupported fonts are filtered out.
          style: 'font-family',
          priority: ExtensionPriority.Low,
          getAttrs: (fontFamily) => {
            // An empty declaration names no font, so don't create a mark for it.
            if (!isString(fontFamily) || !fontFamily.trim()) {
              return false
            }

            return { fontFamily: toSupportedFontFamily(fontFamily) }
          },
        },
        ...(override.parseDOM ?? []),
      ],
      toDOM: (mark) => {
        const { fontFamily } = omitExtraAttributes(
          mark.attrs,
          extra,
        ) as FontFamilyAttributes
        const extraAttrs = extra.dom(mark)
        let style = extraAttrs.style

        // `fontFamily` defaults to `null`; only emit a CSS declaration when a
        // font is actually set.
        if (fontFamily) {
          style = joinStyles({ fontFamily }, style)
        }

        return [
          'span',
          { ...extraAttrs, style, [FONT_FAMILY_ATTRIBUTE]: fontFamily },
          0,
        ]
      },
    }
  }

  /**
   * Get the font family at the current selection (or a provided custom selection).
   *
   * Only the first range returned by `getMarkRanges` is inspected, so a
   * selection spanning several fonts reports a single one.
   *
   * @param position - optional selection to inspect; defaults to the editor's
   *   current selection.
   * @returns that mark's `fontFamily`, or `DEFAULT_FONT_FAMILY` when the
   *   selection has no font-family mark or the mark has no family set.
   */
  @helper()
  getFontFamilyForSelection(
    position?: PrimitiveSelection,
  ): Helper<string | null> {
    const state = this.store.getState()
    const selection = getTextSelection(position ?? state.selection, state.doc)
    const [range] = getMarkRanges(selection, this.type)

    // A mark whose attr is still the `null` default renders no font-family
    // style, so it is reported as the default font too.
    return range?.mark.attrs.fontFamily ?? DEFAULT_FONT_FAMILY
  }
}
