diff --git a/src/TimeDataFrames.jl b/src/TimeDataFrames.jl index 2e7ad3664a76047093680955acbd261094b44571..0cec102f6af89c2fab5a4a839657c19adfcc2cfa 100644 --- a/src/TimeDataFrames.jl +++ b/src/TimeDataFrames.jl @@ -85,7 +85,56 @@ Base.view(df::TimeDataFrame, idx::CartesianIndex{2}) = view(df, idx[1], idx[2]) Base.setindex!(df::TimeDataFrame, val, idx::CartesianIndex{2}) = (df[idx[1], idx[2]] = val) -Base.broadcastable(df::TimeDataFrame) = df +Base.broadcastable(tdf::TimeDataFrame) = tdf +struct TimeDataFrameStyle <: Base.Broadcast.BroadcastStyle end + +Base.Broadcast.BroadcastStyle(::Type{<:TimeDataFrame}) = + TimeDataFrameStyle() + +Base.Broadcast.BroadcastStyle(::TimeDataFrameStyle, ::Base.Broadcast.BroadcastStyle) = TimeDataFrameStyle() +Base.Broadcast.BroadcastStyle(::Base.Broadcast.BroadcastStyle, ::TimeDataFrameStyle) = TimeDataFrameStyle() +Base.Broadcast.BroadcastStyle(::TimeDataFrameStyle, ::TimeDataFrameStyle) = TimeDataFrameStyle() + +function Base.copy(bc::Base.Broadcast.Broadcasted{TimeDataFrameStyle}) + ndim = length(axes(bc)) + if ndim != 2 + throw(DimensionMismatch("cannot broadcast a time data frame into $ndim dimensions")) + end + + data + bcf = Base.Broadcast.flatten(bc) + colnames = unique!([_names(df) for df in bcf.args if df isa AbstractDataFrame]) + if length(colnames) != 1 + wrongnames = setdiff(union(colnames...), intersect(colnames...)) + if isempty(wrongnames) + throw(ArgumentError("Column names in broadcasted data frames " * + "must have the same order")) + else + msg = join(wrongnames, ", ", " and ") + throw(ArgumentError("Column names in broadcasted data frames must match. " * + "Non matching column names are $msg")) + end + end + nrows = length(axes(bcf)[1]) + df = DataFrame() + for i in axes(bcf)[2] + if nrows == 0 + col = Any[] + else + bcf′ = getcolbc(bcf, i) + v1 = bcf′[CartesianIndex(1, i)] + startcol = similar(Vector{typeof(v1)}, nrows) + startcol[1] = v1 + col = copyto_widen!(startcol, bcf′, 2, i) + end + df[!, colnames[1][i]] = col + end + return df +end + + + + Base.ndims(::TimeDataFrame) = 2 Base.ndims(::Type{<:TimeDataFrame}) = 2 index(df::TimeDataFrame) = getfield(getfield(df, :data), :colindex)