/*
 * Decompiled with CFR 0.152.
 */
package ghidra.app.util.opinion;

import ghidra.app.util.MemoryBlockUtils;
import ghidra.app.util.Option;
import ghidra.app.util.bin.ByteProvider;
import ghidra.app.util.bin.StructConverter;
import ghidra.app.util.bin.format.omf.AbstractOmfRecordFactory;
import ghidra.app.util.bin.format.omf.OmfException;
import ghidra.app.util.bin.format.omf.OmfRecord;
import ghidra.app.util.bin.format.omf.OmfUtils;
import ghidra.app.util.bin.format.omf.omf51.Omf51Content;
import ghidra.app.util.bin.format.omf.omf51.Omf51ExternalDef;
import ghidra.app.util.bin.format.omf.omf51.Omf51ExternalDefsRecord;
import ghidra.app.util.bin.format.omf.omf51.Omf51Fixup;
import ghidra.app.util.bin.format.omf.omf51.Omf51FixupRecord;
import ghidra.app.util.bin.format.omf.omf51.Omf51PublicDef;
import ghidra.app.util.bin.format.omf.omf51.Omf51PublicDefsRecord;
import ghidra.app.util.bin.format.omf.omf51.Omf51RecordFactory;
import ghidra.app.util.bin.format.omf.omf51.Omf51Segment;
import ghidra.app.util.bin.format.omf.omf51.Omf51SegmentDefs;
import ghidra.app.util.importer.MessageLog;
import ghidra.app.util.opinion.AbstractProgramLoader;
import ghidra.app.util.opinion.AbstractProgramWrapperLoader;
import ghidra.app.util.opinion.LoadSpec;
import ghidra.app.util.opinion.Loader;
import ghidra.app.util.opinion.QueryOpinionService;
import ghidra.app.util.opinion.QueryResult;
import ghidra.program.database.mem.FileBytes;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.AddressFactory;
import ghidra.program.model.address.AddressSet;
import ghidra.program.model.address.AddressSetView;
import ghidra.program.model.address.AddressSpace;
import ghidra.program.model.data.DataType;
import ghidra.program.model.data.DataUtilities;
import ghidra.program.model.listing.Data;
import ghidra.program.model.listing.Function;
import ghidra.program.model.listing.FunctionManager;
import ghidra.program.model.listing.Program;
import ghidra.program.model.mem.MemoryAccessException;
import ghidra.program.model.mem.MemoryBlock;
import ghidra.program.model.reloc.Relocation;
import ghidra.program.model.symbol.ExternalLocation;
import ghidra.program.model.symbol.SourceType;
import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Loader for Intel 8051 relocatable object files in Object Module Format (OMF-51).
 * <p>
 * {@link #load} works in stages:
 * <ol>
 * <li>segments are laid out as memory blocks;</li>
 * <li>non-variable external definitions get thunk functions in an artificial {@code EXTERNAL}
 * block placed after the last {@code CODE} block;</li>
 * <li>fixups are applied and recorded in the relocation table;</li>
 * <li>public definitions are labeled;</li>
 * <li>the raw records are marked up as data in a {@code RECORDS} block in the {@code OTHER}
 * space.</li>
 * </ol>
 */
public class Omf51Loader extends AbstractProgramWrapperLoader {
    public static final String OMF51_NAME = "Object Module Format (OMF-51)";

    /** Files shorter than this many bytes are rejected without trying to parse a record. */
    public static final long MIN_BYTE_LENGTH = 11L;

    // Fixup reference types handled by applyFixup. The names follow the OMF-51 fixup REF TYP
    // field. Any other type (e.g. 2 and 5) is not applied and is recorded as UNSUPPORTED.
    private static final int REF_TYPE_LOW = 0;
    private static final int REF_TYPE_BYTE = 1;
    private static final int REF_TYPE_HIGH = 3;
    private static final int REF_TYPE_WORD = 4;
    private static final int REF_TYPE_BIT = 6;
    private static final int REF_TYPE_CONV = 7;

    /** Fixup block type whose block ID is an external ID rather than a segment ID. */
    private static final int BLOCK_TYPE_EXTERNAL = 2;

    /** Public-def usage type that is skipped (presumably NUMBER: a value, not an address). */
    private static final int USAGE_TYPE_NUMBER = 5;

    /** Byte address subtracted before scaling by 8 when converting to a bit address (0x20). */
    private static final long BIT_ADDRESSABLE_BASE = 0x20L;

    @Override
    public Collection<LoadSpec> findSupportedLoadSpecs(ByteProvider provider) throws IOException {
        List<LoadSpec> loadSpecs = new ArrayList<>();
        if (provider.length() < MIN_BYTE_LENGTH) {
            return loadSpecs;
        }
        Omf51RecordFactory factory = new Omf51RecordFactory(provider);
        try {
            OmfRecord first = factory.readNextRecord();
            if (factory.getStartRecordTypes().contains(first.getRecordType()) &&
                first.validCheckSum()) {
                List<QueryResult> results = QueryOpinionService.query(getName(), "8051", null);
                for (QueryResult result : results) {
                    loadSpecs.add(new LoadSpec(this, 0L, result));
                }
                if (loadSpecs.isEmpty()) {
                    loadSpecs.add(new LoadSpec(this, 0L, true));
                }
            }
        }
        catch (OmfException | IOException ignored) {
            // The first record could not be parsed: not an OMF-51 file, so offer no load specs.
        }
        return loadSpecs;
    }

    @Override
    protected void load(ByteProvider provider, LoadSpec loadSpec, List<Option> options,
            Program program, TaskMonitor monitor, MessageLog log)
            throws IOException, CancelledException {
        FileBytes fileBytes = MemoryBlockUtils.createFileBytes(program, provider, monitor);
        Omf51RecordFactory factory = new Omf51RecordFactory(provider);
        try {
            List<OmfRecord> records = OmfUtils.readRecords(factory);
            Map<Integer, Address> segmentToAddr =
                processMemoryBlocks(program, fileBytes, records, log, monitor);
            Map<Integer, Address> extIdToAddr = processExternalDefs(program, records, log, monitor);
            performFixups(program, fileBytes, records, segmentToAddr, extIdToAddr, log, monitor);
            markupPublicDefs(program, records, segmentToAddr, log, monitor);
            markupRecords(program, fileBytes, records, log, monitor);
        }
        catch (CancelledException | IOException e) {
            // Already declared by this method: propagate unchanged so cancellation stays visible.
            throw e;
        }
        catch (Exception e) {
            throw new IOException(e);
        }
    }

    /**
     * Creates a memory block for every segment and records where each segment was placed.
     * <p>
     * A segment with content records gets one initialized block per content record (backed by
     * {@code fileBytes}). A segment without content gets a single uninitialized block.
     * Code segments are additionally marked as functions at their start address.
     *
     * @return map from segment ID to the address the segment was placed at
     * @throws Exception if a segment cannot be placed (see {@link #findAddr}) or its type has no
     *     address space
     */
    private Map<Integer, Address> processMemoryBlocks(Program program, FileBytes fileBytes,
            List<OmfRecord> records, MessageLog log, TaskMonitor monitor) throws Exception {
        List<Omf51Segment> segments = OmfUtils.filterRecords(records, Omf51SegmentDefs.class)
                .map(Omf51SegmentDefs::getSegments)
                .flatMap(Collection::stream)
                .sorted((a, b) -> Integer.compare(a.id(), b.id()))
                .toList();
        Map<Integer, List<Omf51Content>> contentMap =
            OmfUtils.filterRecords(records, Omf51Content.class)
                    .collect(Collectors.groupingBy(Omf51Content::getSegId));

        // Segments sharing a name and type are laid out contiguously, so their total size must
        // be known before the first of them is placed.
        Map<String, Integer> segmentSizes = new HashMap<>();
        for (Omf51Segment segment : segments) {
            segmentSizes.compute(key(segment),
                (k, v) -> v == null ? segment.size() : v + segment.size());
        }

        AddressSet usedAddresses = new AddressSet();
        Map<String, Address> segmentEnds = new HashMap<>();
        Map<Integer, Address> segmentToAddr = new HashMap<>();
        for (Omf51Segment segment : segments) {
            String blockName = segment.isAbsolute() ? "<ABSOLUTE>" : segment.name().str();
            if (blockName.isBlank()) {
                blockName = "<NONAME>";
            }
            AddressSpace space = getAddressSpace(program, segment);
            Address segmentAddr = findAddr(segment, segmentSizes, segmentEnds, space, usedAddresses);

            List<Omf51Content> segmentContent = contentMap.get(segment.id());
            if (segmentContent == null) {
                MemoryBlockUtils.createUninitializedBlock(program, false, blockName, segmentAddr,
                    segment.size(), "", space.getName(), true, true, false, log);
            }
            else {
                for (Omf51Content content : segmentContent) {
                    // Absolute content offsets are space addresses; relocatable ones are relative
                    // to the segment. NOTE(review): performFixups always adds the offset to the
                    // segment address, which only agrees for absolute segments whose base is 0.
                    Address contentAddr = segment.isAbsolute()
                            ? space.getAddress(content.getOffset())
                            : segmentAddr.add(content.getOffset());
                    try {
                        MemoryBlockUtils.createInitializedBlock(program, false, blockName,
                            contentAddr, fileBytes, content.getDataIndex(), content.getDataSize(),
                            "", space.getName(), true, !segment.isCode(), segment.isCode(), log);
                    }
                    catch (Exception e) {
                        log.appendMsg("Failed to create block '%s' at %s: %s"
                                .formatted(blockName, contentAddr, e.getMessage()));
                    }
                }
            }
            if (segment.isCode()) {
                AbstractProgramLoader.markAsFunction(program, blockName, segmentAddr);
            }
            segmentToAddr.put(segment.id(), segmentAddr);
        }
        return segmentToAddr;
    }

    /**
     * Gives each non-variable external definition a one-byte thunk function in an artificial
     * {@code EXTERNAL} block in {@code CODE} space.
     * <p>
     * The block is placed right after the last block whose source is {@code CODE}. If there is
     * no such block, it starts at the beginning of {@code CODE} space. This gives fixups that
     * target externals a concrete address.
     *
     * @return map from external ID to its thunk address (empty if there are no externals)
     * @throws Exception if {@code CODE} space has no room for the block or it cannot be created
     */
    private Map<Integer, Address> processExternalDefs(Program program, List<OmfRecord> records,
            MessageLog log, TaskMonitor monitor) throws Exception {
        Map<Integer, Address> map = new HashMap<>();
        List<Omf51ExternalDef> defs =
            OmfUtils.filterRecords(records, Omf51ExternalDefsRecord.class)
                    .map(Omf51ExternalDefsRecord::getDefinitions)
                    .flatMap(Collection::stream)
                    .filter(def -> !def.isVariable())
                    .sorted((a, b) -> Integer.compare(a.getExtId(), b.getExtId()))
                    .toList();
        int externalSize = defs.size();
        if (externalSize == 0) {
            return map;
        }

        AddressSpace codeSpace = program.getAddressFactory().getAddressSpace("CODE");
        Address codeEndAddr = Arrays.stream(program.getMemory().getBlocks())
                .filter(block -> "CODE".equals(block.getSourceName()))
                .map(MemoryBlock::getEnd)
                .max(Address::compareTo)
                .orElse(null);
        long firstOffset = codeEndAddr == null
                ? codeSpace.getMinAddress().getOffset()
                : codeEndAddr.getOffset() + 1;
        long availableSize = codeSpace.getMaxAddress().getOffset() - firstOffset + 1;
        if (availableSize < externalSize) {
            throw new Exception("Not enough CODE space for externals");
        }
        Address externalStart =
            codeEndAddr == null ? codeSpace.getMinAddress() : codeEndAddr.add(1L);

        MemoryBlock block = MemoryBlockUtils.createUninitializedBlock(program, false, "EXTERNAL",
            externalStart, externalSize, "", "CODE", false, false, true, log);
        if (block == null) {
            throw new Exception("Couldn't create EXTERNAL block");
        }
        block.setArtificial(true);
        block.setComment(
            "NOTE: This block is artificial and allows external fixups to work correctly");

        for (int i = 0; i < externalSize; i++) {
            Omf51ExternalDef def = defs.get(i);
            String name = def.getName().str();
            Address addr = externalStart.add(i);
            Function function = program.getFunctionManager()
                    .createFunction(name, addr, new AddressSet(addr), SourceType.IMPORTED);
            ExternalLocation extLoc = program.getExternalManager()
                    .addExtFunction("<EXTERNAL>", name, null, SourceType.IMPORTED);
            function.setThunkedFunction(extLoc.getFunction());
            map.put(def.getExtId(), addr);
        }
        return map;
    }

    /**
     * Applies every fixup record to the content record that immediately precedes it, and
     * records each fixup in the program's relocation table.
     * <p>
     * Fixups whose target cannot be resolved, or whose reference type is not handled, are
     * recorded with status {@code UNSUPPORTED}. Fixups that are written are recorded as
     * {@code APPLIED}.
     *
     * @throws Exception if a fixup record does not follow a content record, or refers to an
     *     unknown segment
     */
    private void performFixups(Program program, FileBytes fileBytes, List<OmfRecord> records,
            Map<Integer, Address> segmentToAddr, Map<Integer, Address> extIdToAddr,
            MessageLog log, TaskMonitor monitor) throws Exception {
        OmfRecord previous = null;
        for (OmfRecord record : records) {
            if (record instanceof Omf51FixupRecord fixupRec) {
                if (!(previous instanceof Omf51Content content)) {
                    throw new Exception("Record prior to fixup is not content!");
                }
                Address segmentAddr = segmentToAddr.get(content.getSegId());
                if (segmentAddr == null) {
                    throw new Exception("Failed to lookup segment ID 0x%x for content fixup!"
                            .formatted(content.getSegId()));
                }
                Address contentAddr = segmentAddr.add(content.getOffset());
                for (Omf51Fixup fixup : fixupRec.getFixups()) {
                    Address refLocAddr = contentAddr.add(fixup.getRefLoc());
                    Address target = resolveFixupTarget(fixup, segmentToAddr, extIdToAddr);
                    Relocation.Status status = Relocation.Status.UNSUPPORTED;
                    if (target != null &&
                        applyFixup(program, refLocAddr, target, fixup.getRefType())) {
                        status = Relocation.Status.APPLIED;
                    }
                    program.getRelocationTable()
                            .add(refLocAddr, status, fixup.getRefType(),
                                new long[] { fixup.getRefLoc(), fixup.getRefType(),
                                    fixup.getBlockType(), fixup.getBlockId(),
                                    fixup.getOffset() },
                                0, null);
                }
            }
            previous = record;
        }
    }

    /**
     * Computes the address a fixup refers to.
     * <p>
     * Block types 0 and 1 refer to a segment by ID. The target is the segment address plus the
     * fixup offset; for {@code REF_TYPE_CONV} the segment byte address is first converted to a
     * bit address. Block type 2 refers to an external by ID.
     *
     * @return the target address, or {@code null} if the block type is not handled or the
     *     external ID is unknown
     * @throws Exception if a segment-relative fixup names an unknown segment ID
     */
    private Address resolveFixupTarget(Omf51Fixup fixup, Map<Integer, Address> segmentToAddr,
            Map<Integer, Address> extIdToAddr) throws Exception {
        return switch (fixup.getBlockType()) {
            case 0, 1 -> {
                Address baseAddr = segmentToAddr.get(fixup.getBlockId());
                if (baseAddr == null) {
                    throw new Exception("Failed to lookup segment ID 0x%x for fixup!"
                            .formatted(fixup.getBlockId()));
                }
                yield fixup.getRefType() == REF_TYPE_CONV
                        ? baseAddr.getNewAddress(
                            (baseAddr.getOffset() - BIT_ADDRESSABLE_BASE) * 8L + fixup.getOffset())
                        : baseAddr.add(fixup.getOffset());
            }
            case BLOCK_TYPE_EXTERNAL -> extIdToAddr.get(fixup.getBlockId());
            default -> null;
        };
    }

    /**
     * Writes the low 16 bits of {@code addr} into program memory at {@code refLocAddr}, encoded
     * according to {@code refType}.
     *
     * @return {@code true} if the fixup was written, {@code false} if {@code refType} is not
     *     handled (nothing is written in that case)
     * @throws Exception if the address is out of range for a BIT/BYTE fixup, or memory cannot be
     *     written ({@link MemoryAccessException})
     */
    private boolean applyFixup(Program program, Address refLocAddr, Address addr, int refType)
            throws Exception {
        int normAddr = (int) addr.getOffset() & 0xFFFF;
        switch (refType) {
            case REF_TYPE_BIT, REF_TYPE_CONV -> {
                if (normAddr > 0x7F) {
                    throw new Exception("Bad address 0x%04x for BIT fixup!".formatted(normAddr));
                }
            }
            case REF_TYPE_BYTE -> {
                if (normAddr > 0xFF) {
                    throw new Exception("Bad address 0x%04x for BYTE fixup!".formatted(normAddr));
                }
            }
            default -> {
                // no range restriction
            }
        }
        switch (refType) {
            case REF_TYPE_LOW, REF_TYPE_BYTE, REF_TYPE_BIT, REF_TYPE_CONV ->
                program.getMemory().setByte(refLocAddr, (byte) (normAddr & 0xFF));
            case REF_TYPE_HIGH ->
                // High byte of the 16-bit address.
                program.getMemory().setByte(refLocAddr, (byte) ((normAddr >> 8) & 0xFF));
            case REF_TYPE_WORD ->
                program.getMemory().setShort(refLocAddr, (short) normAddr);
            default -> {
                return false;
            }
        }
        return true;
    }

    /**
     * Labels every public definition, except those with usage type {@code USAGE_TYPE_NUMBER},
     * and registers it as an external entry point.
     * <p>
     * Non-variable definitions also get a function at their address if none exists yet.
     *
     * @throws Exception if a public definition refers to an unknown segment ID
     */
    private void markupPublicDefs(Program program, List<OmfRecord> records,
            Map<Integer, Address> segmentToAddr, MessageLog log, TaskMonitor monitor)
            throws Exception {
        monitor.setMessage("Marking up public defs...");
        FunctionManager functionMgr = program.getFunctionManager();
        for (OmfRecord record : records) {
            if (!(record instanceof Omf51PublicDefsRecord publicDefRec)) {
                continue;
            }
            for (Omf51PublicDef def : publicDefRec.getDefinitions()) {
                if (def.getUsageType() == USAGE_TYPE_NUMBER) {
                    continue;
                }
                Address segmentAddr = segmentToAddr.get(def.getSegId());
                if (segmentAddr == null) {
                    throw new Exception("Failed to lookup segment ID 0x%x for public def!"
                            .formatted(def.getSegId()));
                }
                Address defAddress = segmentAddr.add(def.getOffset());
                String name = def.getName().str();
                if (!def.isVariable() && functionMgr.getFunctionAt(defAddress) == null) {
                    functionMgr.createFunction(name, defAddress, new AddressSet(defAddress),
                        SourceType.IMPORTED);
                }
                program.getSymbolTable().createLabel(defAddress, name, null, SourceType.IMPORTED);
                program.getSymbolTable().addExternalEntryPoint(defAddress);
            }
        }
    }

    /**
     * Maps the raw record bytes into a {@code RECORDS} block at address 0 of the {@code OTHER}
     * space and applies each record's data type there.
     * <p>
     * This is best effort: failures are logged rather than thrown.
     */
    private void markupRecords(Program program, FileBytes fileBytes, List<OmfRecord> records,
            MessageLog log, TaskMonitor monitor) {
        monitor.setMessage("Marking up records...");
        // Each record occupies getRecordLength() bytes plus 3 more
        // (presumably its type/length header).
        int size = records.stream().mapToInt(r -> r.getRecordLength() + 3).sum();
        try {
            Address recordSpaceAddr = AddressSpace.OTHER_SPACE.getAddress(0L);
            MemoryBlock headerBlock = MemoryBlockUtils.createInitializedBlock(program, true,
                "RECORDS", recordSpaceAddr, fileBytes, 0L, size, "", "", false, false, false, log);
            Address start = headerBlock.getStart();
            for (OmfRecord record : records) {
                try {
                    Data data = DataUtilities.createData(program,
                        start.add(record.getRecordOffset()), record.toDataType(), -1,
                        DataUtilities.ClearDataMode.CHECK_FOR_SPACE);
                    StructConverter.setEndian(data, false);
                }
                catch (Exception e) {
                    log.appendMsg("Failed to markup record type 0x%x at offset 0x%x. %s."
                            .formatted(record.getRecordType(), record.getRecordOffset(),
                                e.getMessage()));
                }
            }
        }
        catch (Exception e) {
            log.appendMsg("Failed to markup records: " + e.getMessage());
        }
    }

    /**
     * Chooses the start address for {@code segment} and reserves its range in
     * {@code usedAddresses}.
     * <p>
     * The absolute segment (relocation type 0, required to have ID 0) is placed at its base
     * address. A relocatable segment (types 1-5) is handled by name and type:
     * <ul>
     * <li>it is appended right after an earlier segment with the same name and type, if any;</li>
     * <li>otherwise it goes at the first gap in {@code space} large enough for the combined size
     * of all segments sharing that name and type.</li>
     * </ul>
     *
     * @throws Exception for an absolute segment with a non-zero ID, an overlapping absolute
     *     segment, or an unsupported relocation type
     */
    private Address findAddr(Omf51Segment segment, Map<String, Integer> segmentSizes,
            Map<String, Address> segmentEnds, AddressSpace space, AddressSet usedAddresses)
            throws Exception {
        return switch (segment.relType()) {
            case 0 -> {
                if (segment.id() != 0) {
                    throw new Exception("Absolute segment does not have ID 0!");
                }
                Address start = space.getAddress(segment.base());
                if (segment.size() > 0) {
                    // Inclusive end address, matching the relocatable case below.
                    Address end = start.add(segment.size() - 1L);
                    if (usedAddresses.intersects(start, end)) {
                        throw new Exception("Absolute segment overlaps with existing segment!");
                    }
                    usedAddresses.add(start, end);
                }
                yield start;
            }
            case 1, 2, 3, 4, 5 -> {
                String key = key(segment);
                Address lastEnd = segmentEnds.get(key);
                if (lastEnd != null) {
                    // Room for every segment with this key was reserved when the first was placed.
                    Address start = lastEnd.add(1L);
                    segmentEnds.put(key, start.add(segment.size() - 1L));
                    yield start;
                }
                int requiredSize = segmentSizes.get(key);
                Address start = space.getMinAddress();
                AddressSet intersection =
                    usedAddresses.intersectRange(start, start.add(requiredSize - 1L));
                while (!intersection.isEmpty()) {
                    start = intersection.getMaxAddress().add(1L);
                    intersection = usedAddresses.intersectRange(start, start.add(requiredSize - 1L));
                }
                usedAddresses.add(start, start.add(requiredSize - 1L));
                segmentEnds.put(key, start.add(segment.size() - 1L));
                yield start;
            }
            default -> throw new Exception(
                "Skipping segment '%s'. Relocation type 0x%x is not yet supported"
                        .formatted(segment.name().str(), segment.relType()));
        };
    }

    /**
     * Returns the program address space for the segment's type:
     * 0 maps to CODE, 1 to EXTMEM, 2 and 3 to INTMEM, and 4 to BITS.
     *
     * @throws Exception for any other segment type
     */
    private AddressSpace getAddressSpace(Program program, Omf51Segment segment) throws Exception {
        AddressFactory addressFactory = program.getAddressFactory();
        return addressFactory.getAddressSpace(switch (segment.getType()) {
            case 0 -> "CODE";
            case 1 -> "EXTMEM";
            case 2, 3 -> "INTMEM";
            case 4 -> "BITS";
            default -> throw new Exception(
                "Unsupported address space: 0x%x".formatted(segment.getType()));
        });
    }

    /** Grouping key for segments that are laid out together: name plus segment type. */
    private String key(Omf51Segment segment) {
        return segment.name().str() + "_" + segment.getType();
    }

    @Override
    public String getName() {
        return OMF51_NAME;
    }
}

