@@ -11,9 +11,13 @@ import type {
1111 StickyOffsets ,
1212} from '../interface' ;
1313import HeaderRow from './HeaderRow' ;
14+ import cls from 'classnames' ;
15+ import { TableProps } from '..' ;
1416
1517function parseHeaderRows < RecordType > (
1618 rootColumns : ColumnsType < RecordType > ,
19+ classNames : TableProps [ 'classNames' ] [ 'header' ] ,
20+ styles : TableProps [ 'styles' ] [ 'header' ] ,
1721) : CellType < RecordType > [ ] [ ] {
1822 const rows : CellType < RecordType > [ ] [ ] = [ ] ;
1923
@@ -29,7 +33,8 @@ function parseHeaderRows<RecordType>(
2933 const colSpans : number [ ] = columns . filter ( Boolean ) . map ( column => {
3034 const cell : CellType < RecordType > = {
3135 key : column . key ,
32- className : column . className || '' ,
36+ className : cls ( column . className , classNames . cell ) || '' ,
37+ style : styles . cell ,
3338 children : column . title ,
3439 column,
3540 colStart : currentColIndex ,
@@ -97,18 +102,33 @@ const Header = <RecordType extends any>(props: HeaderProps<RecordType>) => {
97102
98103 const { stickyOffsets, columns, flattenColumns, onHeaderRow } = props ;
99104
100- const { prefixCls, getComponent } = useContext ( TableContext , [ 'prefixCls' , 'getComponent' ] ) ;
101- const rows = React . useMemo < CellType < RecordType > [ ] [ ] > ( ( ) => parseHeaderRows ( columns ) , [ columns ] ) ;
105+ const { prefixCls, getComponent, classNames, styles } = useContext ( TableContext , [
106+ 'prefixCls' ,
107+ 'getComponent' ,
108+ 'classNames' ,
109+ 'styles' ,
110+ ] ) ;
111+ const { header : headerCls = { } } = classNames || { } ;
112+ const { header : headerStyles = { } } = styles || { } ;
113+ const rows = React . useMemo < CellType < RecordType > [ ] [ ] > (
114+ ( ) => parseHeaderRows ( columns , headerCls , headerStyles ) ,
115+ [ columns , headerCls , headerStyles ] ,
116+ ) ;
102117
103118 const WrapperComponent = getComponent ( [ 'header' , 'wrapper' ] , 'thead' ) ;
104119 const trComponent = getComponent ( [ 'header' , 'row' ] , 'tr' ) ;
105120 const thComponent = getComponent ( [ 'header' , 'cell' ] , 'th' ) ;
106121
107122 return (
108- < WrapperComponent className = { `${ prefixCls } -thead` } >
123+ < WrapperComponent
124+ className = { cls ( `${ prefixCls } -thead` , headerCls . wrapper ) }
125+ style = { headerStyles . wrapper }
126+ >
109127 { rows . map ( ( row , rowIndex ) => {
110128 const rowNode = (
111129 < HeaderRow
130+ classNames = { headerCls }
131+ styles = { headerStyles }
112132 key = { rowIndex }
113133 flattenColumns = { flattenColumns }
114134 cells = { row }
0 commit comments